diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat deleted file mode 100755 index 94f5d1023..000000000 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ /dev/null @@ -1,3 +0,0 @@ -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -pause diff --git a/.ci/nightly/windows_base_files/run_nvidia_gpu.bat b/.ci/nightly/windows_base_files/run_nvidia_gpu.bat deleted file mode 100755 index 8ee2f3402..000000000 --- a/.ci/nightly/windows_base_files/run_nvidia_gpu.bat +++ /dev/null @@ -1,2 +0,0 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention -pause diff --git a/.github/workflows/test-ui.yaml b/.github/workflows/test-ui.yaml index 950691755..4b8b97934 100644 --- a/.github/workflows/test-ui.yaml +++ b/.github/workflows/test-ui.yaml @@ -22,5 +22,5 @@ jobs: run: | npm ci npm run test:generate - npm test + npm test -- --verbose working-directory: ./tests-ui diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index b793f7fe2..90e09d27a 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -2,6 +2,24 @@ name: "Windows Release Nightly pytorch" on: workflow_dispatch: + inputs: + cu: + description: 'cuda version' + required: true + type: string + default: "121" + + python_minor: + description: 'python minor version' + required: true + type: string + default: "12" + + python_patch: + description: 'python patch version' + required: true + type: string + default: "1" # push: # branches: # - master @@ -20,21 +38,21 @@ jobs: persist-credentials: false - uses: actions/setup-python@v4 with: - python-version: '3.11.6' + python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }} - shell: bash run: | cd .. cp -r ComfyUI ComfyUI_copy - curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip + curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip unzip python_embeded.zip -d python_embeded cd python_embeded - echo 'import site' >> ./python311._pth + echo 'import site' >> ./python3${{ inputs.python_minor }}._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* - sed -i '1i../ComfyUI' ./python311._pth + sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth cd .. 
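# taesd supplies the tiny autoencoder checkpoints ComfyUI loads for fast latent previews in the portable build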
git clone https://github.com/comfyanonymous/taesd @@ -49,9 +67,10 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/windows_base_files/* ./ - cp -r ComfyUI/.ci/nightly/update_windows/* ./update/ - cp -r ComfyUI/.ci/nightly/windows_base_files/* ./ + echo "..\python_embeded\python.exe .\update.py ..\ComfyUI\\ + ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 + pause" > ./update/update_comfyui_and_python_dependencies.bat cd .. "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 202121e10..000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "path-intellisense.mappings": { - "../": "${workspaceFolder}/web/extensions/core" - }, - "[python]": { - "editor.defaultFormatter": "ms-python.autopep8" - }, - "python.formatting.provider": "none" -} diff --git a/__init__.py b/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 625fae33f..ff1ca98c9 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -53,7 +53,7 @@ class ControlNet(nn.Module): transformer_depth_middle=None, transformer_depth_output=None, device=None, - operations=ops, + operations=ops.disable_weight_init, **kwargs, ): super().__init__() @@ -141,24 +141,24 @@ class ControlNet(nn.Module): ) ] ) - self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)]) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)]) self.input_hint_block = TimestepEmbedSequential( - operations.conv_nd(dims, hint_channels, 16, 3, padding=1), + operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 16, 16, 3, padding=1), + operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2), + operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 32, 32, 3, padding=1), + operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2), + operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 96, 96, 3, padding=1), + operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2), + operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1)) + operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device) ) self._feature_size = model_channels @@ -206,7 +206,7 @@ class ControlNet(nn.Module): ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) - self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) + self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)) self._feature_size += ch 
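# keep the same channel bookkeeping as the UNet encoder this ControlNet mirrors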
input_block_chans.append(ch) if level != len(channel_mult) - 1: @@ -234,7 +234,7 @@ class ControlNet(nn.Module): ) ch = out_ch input_block_chans.append(ch) - self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) + self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)) ds *= 2 self._feature_size += ch @@ -276,14 +276,14 @@ class ControlNet(nn.Module): operations=operations )] self.middle_block = TimestepEmbedSequential(*mid_block) - self.middle_block_out = self.make_zero_conv(ch, operations=operations) + self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device) self._feature_size += ch - def make_zero_conv(self, channels, operations=None): - return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) + def make_zero_conv(self, channels, operations=None, dtype=None, device=None): + return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)) def forward(self, x, hint, timesteps, context, y=None, **kwargs): - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) guided_hint = self.input_hint_block(hint, emb, context) @@ -295,7 +295,7 @@ class ControlNet(nn.Module): assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x for module, zero_conv in zip(self.input_blocks, self.zero_convs): if guided_hint is not None: h = module(h, emb, context) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 58e74f62f..a9d23b446 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -55,13 +55,19 @@ fp_group = parser.add_mutually_exclusive_group() fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") -parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") +fpunet_group = parser.add_mutually_exclusive_group() +fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. 
This should only be used for testing stuff.") +fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.") +fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") +fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") +parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.") + fpte_group = parser.add_mutually_exclusive_group() fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") @@ -98,7 +104,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in vram when it can.") - +parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") diff --git a/comfy/clip_model.py b/comfy/clip_model.py new file mode 100644 index 000000000..7397b7a26 --- /dev/null +++ b/comfy/clip_model.py @@ -0,0 +1,188 @@ +import torch +from comfy.ldm.modules.attention import optimized_attention_for_device + +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device, operations): + super().__init__() + + self.heads = heads + self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x, mask=None, optimized_attention=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + out = optimized_attention(q, k, v, self.heads, mask) + return self.out_proj(out) + +ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), + "gelu": torch.nn.functional.gelu, +} + +class CLIPMLP(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations): + super().__init__() + self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device) + self.activation = ACTIVATIONS[activation] + self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations): + super().__init__() +
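# pre-norm transformer layer: LayerNorm runs before the attention and MLP sublayers, with a residual add after each +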
self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations) + self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device) + self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations) + + def forward(self, x, mask=None, optimized_attention=None): + x += self.self_attn(self.layer_norm1(x), mask, optimized_attention) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations): + super().__init__() + self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) + + def forward(self, x, mask=None, intermediate_output=None): + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None) + + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask, optimized_attention) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + super().__init__() + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) + self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) + + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if mask is not None: + mask += causal_mask + else: + mask = causal_mask + + x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + + pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),] + return x, i, pooled_output + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + 
super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device, operations) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + return self.text_model(*args, **kwargs) + +class CLIPVisionEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None): + super().__init__() + self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device)) + + self.patch_embedding = operations.Conv2d( + in_channels=num_channels, + out_channels=embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=False, + dtype=dtype, + device=device + ) + + num_patches = (image_size // patch_size) ** 2 + num_positions = num_patches + 1 + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, pixel_values): + embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) + return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device) + + +class CLIPVision(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + + self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations) + self.pre_layrnorm = operations.LayerNorm(embed_dim) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) + self.post_layernorm = operations.LayerNorm(embed_dim) + + def forward(self, pixel_values, attention_mask=None, intermediate_output=None): + x = self.embeddings(pixel_values) + x = self.pre_layrnorm(x) + #TODO: attention_mask? + x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output) + pooled_output = self.post_layernorm(x[:, 0, :]) + return x, i, pooled_output + +class CLIPVisionModelProjection(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.vision_model = CLIPVision(config_dict, dtype, device, operations) + self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False) + + def forward(self, *args, **kwargs): + x = self.vision_model(*args, **kwargs) + out = self.visual_projection(x[2]) + return (x[0], x[1], out) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1b109ff45..d3fdd5223 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,36 +1,43 @@ -from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils from .utils import load_torch_file, transformers_convert import os import torch -import contextlib +import json + from . import ops from . import model_patcher from . import model_management +from . 
import clip_model + + +class Output: + def __getitem__(self, key): + return getattr(self, key) + def __setitem__(self, key, item): + setattr(self, key, item) def clip_preprocess(image, size=224): mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) - scale = (size / min(image.shape[1], image.shape[2])) - image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) - h = (image.shape[2] - size)//2 - w = (image.shape[3] - size)//2 - image = image[:,:,h:h+size,w:w+size] + image = image.movedim(-1, 1) + if not (image.shape[2] == size and image.shape[3] == size): + scale = (size / min(image.shape[2], image.shape[3])) + image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] image = torch.clip((255. * image), 0, 255).round() / 255.0 return (image - mean.view([3,1,1])) / std.view([3,1,1]) class ClipVisionModel(): def __init__(self, json_config): - config = CLIPVisionConfig.from_json_file(json_config) + with open(json_config) as f: + config = json.load(f) + self.load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() - self.dtype = torch.float32 - if model_management.should_use_fp16(self.load_device, prioritize_performance=False): - self.dtype = torch.float16 - - with ops.use_comfy_ops(offload_device, self.dtype): - with modeling_utils.no_init_weights(): - self.model = CLIPVisionModelWithProjection(config) - self.model.to(self.dtype) + self.dtype = model_management.text_encoder_dtype(self.load_device) + self.model = clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ops.manual_cast) + self.model.eval() self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): @@ -38,25 +45,13 @@ class ClipVisionModel(): def encode_image(self, image): model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device)) - - if self.dtype != torch.float32: - precision_scope = torch.autocast - else: - precision_scope = lambda a, b: contextlib.nullcontext(a) - - with precision_scope(model_management.get_autocast_device(self.load_device), torch.float32): - outputs = self.model(pixel_values=pixel_values, output_hidden_states=True) - - for k in outputs: - t = outputs[k] - if t is not None: - if k == 'hidden_states': - outputs["penultimate_hidden_states"] = t[-2].cpu() - outputs["hidden_states"] = None - else: - outputs[k] = t.cpu() + pixel_values = clip_preprocess(image.to(self.load_device)).float() + out = self.model(pixel_values=pixel_values, intermediate_output=-2) + outputs = Output() + outputs["last_hidden_state"] = out[0].to(model_management.intermediate_device()) + outputs["image_embeds"] = out[2].to(model_management.intermediate_device()) + outputs["penultimate_hidden_states"] = out[1].to(model_management.intermediate_device()) return outputs def convert_to_transformers(sd, prefix): @@ -86,6 +81,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): if convert_keys: sd = convert_to_transformers(sd, prefix) if "vision_model.encoder.layers.47.layer_norm1.weight" in sd: + # 
todo: fix the importlib issue here json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json") elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 825953552..59ba52cf9 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -13,6 +13,8 @@ import typing from dataclasses import dataclass from typing import Tuple import sys +import gc +import inspect import torch @@ -442,6 +444,8 @@ class PromptExecutor: for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None + if model_management.DISABLE_SMART_MEMORY: + model_management.unload_all_models() @@ -462,6 +466,14 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t errors = [] valid = True + # todo: investigate if these are at the right indent level + info = None + val = None + + validate_function_inputs = [] + if hasattr(obj_class, "VALIDATE_INPUTS"): + validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args + for x in required_inputs: if x not in inputs: error = { @@ -591,29 +603,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t errors.append(error) continue - if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all = get_input_data(inputs, obj_class, unique_id) - # ret = obj_class.VALIDATE_INPUTS(**input_data_all) - ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for i, r3 in enumerate(ret): - if r3 is not True: - details = f"{x}" - if r3 is not False: - details += f" - {str(r3)}" - - error = { - "type": "custom_validation_failed", - "message": "Custom validation failed for node", - "details": details, - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - else: + if x not in validate_function_inputs: if isinstance(type_input, list): if val not in type_input: input_config = info @@ -640,6 +630,35 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t errors.append(error) continue + if len(validate_function_inputs) > 0: + input_data_all = get_input_data(inputs, obj_class, unique_id) + input_filtered = {} + for x in input_data_all: + if x in validate_function_inputs: + input_filtered[x] = input_data_all[x] + + #ret = obj_class.VALIDATE_INPUTS(**input_filtered) + ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") + for x in input_filtered: + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f" - {str(r)}" + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if len(errors) > 0 or valid is not True: ret = (False, errors, unique_id) else: @@ -771,7 +790,7 @@ class PromptQueue: self.server.queue_updated() self.not_empty.notify() - def get(self, timeout=None) -> typing.Tuple[QueueTuple, int]: + def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, int]]: with self.not_empty: while len(self.queue) == 0: self.not_empty.wait(timeout=timeout) diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index c8ad20621..eed7577dd 100644 --- a/comfy/cmd/folder_paths.py +++ 
b/comfy/cmd/folder_paths.py @@ -188,8 +188,7 @@ def cached_filename_list_(folder_name): if folder_name not in filename_list_cache: return None out = filename_list_cache[folder_name] - if time.perf_counter() < (out[2] + 0.5): - return out + for x in out[1]: time_modified = out[1][x] folder = x diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index bdf77a039..38abd96fe 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -23,9 +23,9 @@ def execute_prestartup_script(): return False node_paths = folder_paths.get_folder_paths("custom_nodes") + node_prestartup_times = [] for custom_node_path in node_paths: possible_modules = os.listdir(custom_node_path) if os.path.exists(custom_node_path) else [] - node_prestartup_times = [] for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) @@ -69,6 +69,10 @@ if args.cuda_device is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) print("Set cuda device to:", args.cuda_device) +if args.deterministic: + if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" + from .. import utils import yaml @@ -78,12 +82,12 @@ from .server import BinaryEventTypes from .. import model_management -def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer): +def prompt_worker(q, _server): e = execution.PromptExecutor(_server) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 - + current_time = 0.0 while True: timeout = None if need_gc: @@ -94,11 +98,13 @@ def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer) item, item_id = queue_item execution_start_time = time.perf_counter() prompt_id = item[1] + _server.last_prompt_id = prompt_id + e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True q.task_done(item_id, e.outputs_ui) if _server.client_id is not None: - _server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id) + _server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time @@ -119,7 +125,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): def hijack_progress(server): def hook(value, total, preview_image): - server.send_sync("progress", {"value": value, "max": total}, server.client_id) + model_management.throw_exception_if_processing_interrupted() + progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} + + server.send_sync("progress", progress, server.client_id) if preview_image is not None: server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) @@ -204,7 +213,7 @@ def main(): print(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) - #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes + # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. 
nodes folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 78407adeb..06a621591 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -734,7 +734,8 @@ class PromptServer(): message = self.encode_bytes(event, data) if sid is None: - for ws in self.sockets.values(): + sockets = list(self.sockets.values()) + for ws in sockets: await send_socket_catch_exception(ws.send_bytes, message) elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_bytes, message) @@ -743,7 +744,8 @@ class PromptServer(): message = {"type": event, "data": data} if sid is None: - for ws in self.sockets.values(): + sockets = list(self.sockets.values()) + for ws in sockets: await send_socket_catch_exception(ws.send_json, message) elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_json, message) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 919dfbea3..0b11ad1f7 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,10 +1,13 @@ import torch import math import os +import contextlib + from . import utils from . import model_management from . import model_detection from . import model_patcher +from . import ops from .cldm import cldm from .t2i_adapter import adapter @@ -34,13 +37,13 @@ class ControlBase: self.cond_hint = None self.strength = 1.0 self.timestep_percent_range = (0.0, 1.0) + self.global_average_pooling = False self.timestep_range = None if device is None: device = model_management.get_torch_device() self.device = device self.previous_controlnet = None - self.global_average_pooling = False def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)): self.cond_hint_original = cond_hint @@ -75,6 +78,7 @@ class ControlBase: c.cond_hint_original = self.cond_hint_original c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range + c.global_average_pooling = self.global_average_pooling def inference_memory_requirements(self, dtype): if self.previous_controlnet is not None: @@ -127,12 +131,14 @@ class ControlBase: return out class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None): + def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) self.control_model = control_model - self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) + self.load_device = load_device + self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling self.model_sampling_current = None + self.manual_cast_dtype = manual_cast_dtype def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -146,28 +152,31 @@ class ControlNet(ControlBase): else: return None + dtype = self.control_model.dtype + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype + output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * 8 != 
self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = cond['c_crossattn'] y = cond.get('y', None) if y is not None: - y = y.to(self.control_model.dtype) + y = y.to(dtype) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): - c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) self.copy_to(c) return c @@ -198,10 +207,11 @@ class ControlLoraOps: self.bias = None def forward(self, input): + weight, bias = ops.cast_bias_weight(self, input) if self.up is not None: - return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) else: - return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) + return torch.nn.functional.linear(input, weight, bias) class Conv2d(torch.nn.Module): def __init__( @@ -237,16 +247,11 @@ class ControlLoraOps: def forward(self, input): + weight, bias = ops.cast_bias_weight(self, input) if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) - - def conv_nd(self, dims, *args, **kwargs): - if dims == 2: - return self.Conv2d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") + return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) class ControlLora(ControlNet): @@ -260,17 +265,26 @@ class ControlLora(ControlNet): controlnet_config = model.model_config.unet_config.copy() controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = 
self.control_weights["input_hint_block.0.weight"].shape[1] - controlnet_config["operations"] = ControlLoraOps() - self.control_model = cldm.ControlNet(**controlnet_config) + self.manual_cast_dtype = model.manual_cast_dtype dtype = model.get_dtype() - self.control_model.to(dtype) + if self.manual_cast_dtype is None: + class control_lora_ops(ControlLoraOps, ops.disable_weight_init): + pass + else: + class control_lora_ops(ControlLoraOps, ops.manual_cast): + pass + dtype = self.manual_cast_dtype + + controlnet_config["operations"] = control_lora_ops + controlnet_config["dtype"] = dtype + self.control_model = cldm.ControlNet(**controlnet_config) self.control_model.to(model_management.get_torch_device()) diffusion_model = model.diffusion_model sd = diffusion_model.state_dict() cm = self.control_model.state_dict() for k in sd: - weight = model_management.resolve_lowvram_weight(sd[k], diffusion_model, k) + weight = sd[k] try: utils.set_attr(self.control_model, k, weight) except: @@ -367,6 +381,10 @@ def load_controlnet(ckpt_path, model=None): if controlnet_config is None: unet_dtype = model_management.unet_dtype() controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config + load_device = model_management.get_torch_device() + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) + if manual_cast_dtype is not None: + controlnet_config["operations"] = ops.manual_cast controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = cldm.ControlNet(**controlnet_config) @@ -395,14 +413,12 @@ def load_controlnet(ckpt_path, model=None): missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) print(missing, unexpected) - control_model = control_model.to(unet_dtype) - global_average_pooling = False filename = os.path.splitext(ckpt_path)[0] if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling global_average_pooling = True - control = ControlNet(control_model, global_average_pooling=global_average_pooling) + control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control class T2IAdapter(ControlBase): diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index c209087e0..2252a075e 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -33,3 +33,7 @@ class SDXL(LatentFormat): [-0.3112, -0.2359, -0.2076] ] self.taesd_decoder_name = "taesdxl_decoder" + +class SD_X4(LatentFormat): + def __init__(self): + self.scale_factor = 0.08333 diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 0f9ac17cd..f85c3369a 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -9,6 +9,7 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution from ..util import instantiate_from_config, get_obj_from_str from ..modules.ema import LitEma +from ... 
import ops class DiagonalGaussianRegularizer(torch.nn.Module): def __init__(self, sample: bool = True): @@ -162,12 +163,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine): }, **kwargs, ) - self.quant_conv = torch.nn.Conv2d( + self.quant_conv = ops.disable_weight_init.Conv2d( (1 + ddconfig["double_z"]) * ddconfig["z_channels"], (1 + ddconfig["double_z"]) * embed_dim, 1, ) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim def get_autoencoder_params(self) -> list: diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 97d03a589..17dda7c58 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -18,6 +18,7 @@ if model_management.xformers_enabled(): from ...cli_args import args from ... import ops +ops = ops.disable_weight_init # CrossAttn precision handling if args.dont_upcast_attention: @@ -82,16 +83,6 @@ class FeedForward(nn.Module): def forward(self, x): return self.net(x) - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) @@ -112,19 +103,20 @@ def attention_basic(q, k, v, heads, mask=None): # force cast to fp32 to avoid overflowing if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * scale + sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale del q, k if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) + if mask.dtype == torch.bool: + mask = rearrange(mask, 'b ... 
-> b (...)') #TODO: check if this bool part matches pytorch attention + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + else: + sim += mask # attention, what we cannot get enough of sim = sim.softmax(dim=-1) @@ -349,6 +341,18 @@ else: if model_management.pytorch_attention_enabled(): optimized_attention_masked = attention_pytorch +def optimized_attention_for_device(device, mask=False): + if device == torch.device("cpu"): #TODO + if model_management.pytorch_attention_enabled(): + return attention_pytorch + else: + return attention_basic + if mask: + return optimized_attention_masked + + return optimized_attention + + class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops): super().__init__() @@ -393,7 +397,7 @@ class BasicTransformerBlock(nn.Module): self.is_res = inner_dim == dim if self.ff_in: - self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device) + self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device) self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.disable_self_attn = disable_self_attn @@ -413,10 +417,10 @@ class BasicTransformerBlock(nn.Module): self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none - self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) - self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) - self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head @@ -558,7 +562,7 @@ class SpatialTransformer(nn.Module): context_dim = [context_dim] * depth self.in_channels = in_channels inner_dim = n_heads * d_head - self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) if not use_linear: self.proj_in = operations.Conv2d(in_channels, inner_dim, diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 943a3dd1a..7d66ef4c3 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -8,6 +8,7 @@ from typing import Optional, Any from .... import model_management from .... 
import ops +ops = ops.disable_weight_init if model_management.xformers_enabled_vae(): import xformers @@ -40,7 +41,7 @@ def nonlinearity(x): def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) class Upsample(nn.Module): diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 222435ba0..903dc2801 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -12,13 +12,13 @@ from .util import ( checkpoint, avg_pool_nd, zero_module, - normalization, timestep_embedding, AlphaBlender, ) from ..attention import SpatialTransformer, SpatialVideoTransformer, default from ...util import exists from .... import ops +ops = ops.disable_weight_init class TimestepBlock(nn.Module): """ @@ -177,7 +177,7 @@ class ResBlock(TimestepBlock): padding = kernel_size // 2 self.in_layers = nn.Sequential( - nn.GroupNorm(32, channels, dtype=dtype, device=device), + operations.GroupNorm(32, channels, dtype=dtype, device=device), nn.SiLU(), operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), ) @@ -206,12 +206,11 @@ class ResBlock(TimestepBlock): ), ) self.out_layers = nn.Sequential( - nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), + operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) - ), + operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) + , ) if self.out_channels == channels: @@ -438,9 +437,6 @@ class UNetModel(nn.Module): operations=ops, ): super().__init__() - assert use_spatial_transformer == True, "use_spatial_transformer has to be true" - if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' if context_dim is not None: assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
@@ -457,7 +453,6 @@ class UNetModel(nn.Module): if num_head_channels == -1: assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' - self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels @@ -503,7 +498,7 @@ class UNetModel(nn.Module): if self.num_classes is not None: if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(num_classes, time_embed_dim) + self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device) elif self.num_classes == "continuous": print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) @@ -810,13 +805,13 @@ class UNetModel(nn.Module): self._feature_size += ch self.out = nn.Sequential( - nn.GroupNorm(32, ch, dtype=self.dtype, device=device), + operations.GroupNorm(32, ch, dtype=self.dtype, device=device), nn.SiLU(), zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - nn.GroupNorm(32, ch, dtype=self.dtype, device=device), + operations.GroupNorm(32, ch, dtype=self.dtype, device=device), operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) @@ -842,14 +837,14 @@ class UNetModel(nn.Module): self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py index 7a86adeb4..9652cfc76 100644 --- a/comfy/ldm/modules/diffusionmodules/upscaling.py +++ b/comfy/ldm/modules/diffusionmodules/upscaling.py @@ -41,10 +41,14 @@ class AbstractLowScaleModel(nn.Module): self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + def q_sample(self, x_start, t, noise=None, seed=None): + if noise is None: + if seed is None: + noise = torch.randn_like(x_start) + else: + noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device) + return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise) def forward(self, x): return x, None @@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): super().__init__(noise_schedule_config=noise_schedule_config) self.max_noise_level = max_noise_level - def forward(self, x, noise_level=None): + def forward(self, x, noise_level=None, seed=None): if noise_level is None: noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() else: assert isinstance(noise_level, torch.Tensor) - z = self.q_sample(x, noise_level) + z = self.q_sample(x, noise_level, seed=seed) return z, noise_level diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 85cf433c3..88346dc2e 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -16,7 +16,6 @@ import numpy as np from einops import repeat, rearrange from ...util import instantiate_from_config -from .... import ops class AlphaBlender(nn.Module): strategies = ["learned", "fixed", "learned_with_images"] @@ -52,9 +51,9 @@ class AlphaBlender(nn.Module): if self.merge_strategy == "fixed": # make shape compatible # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) - alpha = self.mix_factor + alpha = self.mix_factor.to(image_only_indicator.device) elif self.merge_strategy == "learned": - alpha = torch.sigmoid(self.mix_factor) + alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device)) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) elif self.merge_strategy == "learned_with_images": @@ -62,7 +61,7 @@ class AlphaBlender(nn.Module): alpha = torch.where( image_only_indicator.bool(), torch.ones(1, 1, device=image_only_indicator.device), - rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), + rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), ) alpha = rearrange(alpha, self.rearrange_pattern) # make shape compatible @@ -273,46 +272,6 @@ def mean_flat(tensor): return tensor.mean(dim=list(range(1, len(tensor.shape)))) -def normalization(channels, dtype=None): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. - """ - return GroupNorm32(32, channels, dtype=dtype) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. 
- """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return ops.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return ops.Linear(*args, **kwargs) - - def avg_pool_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D average pooling module. diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py index b59bf204b..66767b587 100644 --- a/comfy/ldm/modules/encoders/noise_aug_modules.py +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -15,12 +15,12 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): def scale(self, x): # re-normalize to centered mean and unit variance - x = (x - self.data_mean) * 1. / self.data_std + x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device) return x def unscale(self, x): # back to original data stats - x = (x * self.data_std) + self.data_mean + x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device) return x def forward(self, x, noise_level=None): diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py index 04cb7eb00..c1966ea42 100644 --- a/comfy/ldm/modules/temporal_ae.py +++ b/comfy/ldm/modules/temporal_ae.py @@ -5,6 +5,7 @@ import torch from einops import rearrange, repeat from ... import ops +ops = ops.disable_weight_init from .diffusionmodules.model import ( AttnBlock, @@ -81,14 +82,14 @@ class VideoResBlock(ResnetBlock): x = self.time_stack(x, temb) - alpha = self.get_alpha(bs=b // timesteps) + alpha = self.get_alpha(bs=b // timesteps).to(x.device) x = alpha * x + (1.0 - alpha) * x_mix x = rearrange(x, "b c t h w -> (b t) c h w") return x -class AE3DConv(torch.nn.Conv2d): +class AE3DConv(ops.Conv2d): def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): super().__init__(in_channels, out_channels, *args, **kwargs) if isinstance(video_kernel_size, Iterable): @@ -96,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d): else: padding = int(video_kernel_size // 2) - self.time_mix_conv = torch.nn.Conv3d( + self.time_mix_conv = ops.Conv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=video_kernel_size, @@ -166,7 +167,7 @@ class AttnVideoBlock(AttnBlock): emb = emb[:, None, :] x_mix = x_mix + emb - alpha = self.get_alpha() + alpha = self.get_alpha().to(x.device) x_mix = self.time_mix_block(x_mix, timesteps=timesteps) x = alpha * x + (1.0 - alpha) * x_mix # alpha merge diff --git a/comfy/lora.py b/comfy/lora.py index 78e5d4ce8..a19e7161d 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -43,7 +43,7 @@ def load_lora(lora, to_load): if mid_name is not None and mid_name in lora.keys(): mid = lora[mid_name] loaded_keys.add(mid_name) - patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid)) loaded_keys.add(A_name) loaded_keys.add(B_name) @@ -64,7 +64,7 @@ def load_lora(lora, to_load): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ 
-116,8 +116,19 @@ def load_lora(lora, to_load): loaded_keys.add(lokr_t2_name) if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)) + #glora + a1_name = "{}.a1.weight".format(x) + a2_name = "{}.a2.weight".format(x) + b1_name = "{}.b1.weight".format(x) + b2_name = "{}.b2.weight".format(x) + if a1_name in lora: + patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha)) + loaded_keys.add(a1_name) + loaded_keys.add(a2_name) + loaded_keys.add(b1_name) + loaded_keys.add(b2_name) w_norm_name = "{}.w_norm".format(x) b_norm_name = "{}.b_norm".format(x) @@ -126,21 +137,21 @@ def load_lora(lora, to_load): if w_norm is not None: loaded_keys.add(w_norm_name) - patch_dict[to_load[x]] = (w_norm,) + patch_dict[to_load[x]] = ("diff", (w_norm,)) if b_norm is not None: loaded_keys.add(b_norm_name) - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,)) diff_name = "{}.diff".format(x) diff_weight = lora.get(diff_name, None) if diff_weight is not None: - patch_dict[to_load[x]] = (diff_weight,) + patch_dict[to_load[x]] = ("diff", (diff_weight,)) loaded_keys.add(diff_name) diff_bias_name = "{}.diff_b".format(x) diff_bias = lora.get(diff_bias_name, None) if diff_bias is not None: - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) loaded_keys.add(diff_bias_name) for x in lora.keys(): diff --git a/comfy/model_base.py b/comfy/model_base.py index 38bdfe21b..7a0802e2b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,10 +1,12 @@ import torch -from .ldm.modules.diffusionmodules.openaimodel import UNetModel +from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation -from .ldm.modules.diffusionmodules.openaimodel import Timestep +from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from . import model_management from . import conds +from . import ops from enum import Enum +import contextlib from . 
import utils class ModelType(Enum): @@ -13,7 +15,7 @@ class ModelType(Enum): V_PREDICTION_EDM = 3 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM +from .model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM def model_sampling(model_config, model_type): @@ -40,9 +42,14 @@ class BaseModel(torch.nn.Module): unet_config = model_config.unet_config self.latent_format = model_config.latent_format self.model_config = model_config + self.manual_cast_dtype = model_config.manual_cast_dtype if not unet_config.get("disable_unet_model_creation", False): - self.diffusion_model = UNetModel(**unet_config, device=device) + if self.manual_cast_dtype is not None: + operations = ops.manual_cast + else: + operations = ops.disable_weight_init + self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) @@ -61,15 +68,21 @@ class BaseModel(torch.nn.Module): context = c_crossattn dtype = self.get_dtype() + + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype + xc = xc.to(dtype) t = self.model_sampling.timestep(t).float() context = context.to(dtype) extra_conds = {} for o in kwargs: extra = kwargs[o] - if hasattr(extra, "to"): - extra = extra.to(dtype) + if hasattr(extra, "dtype"): + if extra.dtype != torch.int and extra.dtype != torch.long: + extra = extra.to(dtype) extra_conds[o] = extra + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() return self.model_sampling.calculate_denoised(sigma, model_output, x) @@ -117,6 +130,10 @@ class BaseModel(torch.nn.Module): adm = self.encode_adm(**kwargs) if adm is not None: out['y'] = conds.CONDRegular(adm) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = conds.CONDCrossAttn(cross_attn) return out def load_model_weights(self, sd, unet_prefix=""): @@ -144,11 +161,7 @@ class BaseModel(torch.nn.Module): def state_dict_for_saving(self, clip_state_dict, vae_state_dict): clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) - unet_sd = self.diffusion_model.state_dict() - unet_state_dict = {} - for k in unet_sd: - unet_state_dict[k] = model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k) - + unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) if self.get_dtype() == torch.float16: @@ -165,9 +178,12 @@ class BaseModel(torch.nn.Module): def memory_required(self, input_shape): if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention(): + dtype = self.get_dtype() + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype #TODO: this needs to be tweaked area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024) + return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. 
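The BaseModel hunks above thread manual_cast_dtype through both the input casting in _apply_model and the VRAM estimate in memory_required. As a minimal sketch of that casting rule, grounded only in what the hunks show (the helper name cast_model_inputs is hypothetical):

import torch

# When a manual cast dtype is set, inputs are cast to it instead of the
# weight dtype; integer-typed extras (index-like conds) are left uncast,
# mirroring the `extra.dtype != torch.int and extra.dtype != torch.long`
# check in the diff.
def cast_model_inputs(xc, context, extras, weight_dtype, manual_cast_dtype=None):
    dtype = manual_cast_dtype if manual_cast_dtype is not None else weight_dtype
    xc = xc.to(dtype)
    context = context.to(dtype)
    out = {}
    for name, extra in extras.items():
        if hasattr(extra, "dtype") and extra.dtype not in (torch.int, torch.long):
            extra = extra.to(dtype)
        out[name] = extra
    return xc, context, out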
area = input_shape[0] * input_shape[2] * input_shape[3] @@ -307,9 +323,75 @@ class SVD_img2vid(BaseModel): out['c_concat'] = conds.CONDNoiseShape(latent_image) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = conds.CONDCrossAttn(cross_attn) + if "time_conditioning" in kwargs: out["time_context"] = conds.CONDCrossAttn(kwargs["time_conditioning"]) out['image_only_indicator'] = conds.CONDConstant(torch.zeros((1,), device=device)) out['num_video_frames'] = conds.CONDConstant(noise.shape[0]) return out + +class Stable_Zero123(BaseModel): + def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): + super().__init__(model_config, model_type, device=device) + self.cc_projection = ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device) + self.cc_projection.weight.copy_(cc_projection_weight) + self.cc_projection.bias.copy_(cc_projection_bias) + + def extra_conds(self, **kwargs): + out = {} + + latent_image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + + if latent_image is None: + latent_image = torch.zeros_like(noise) + + if latent_image.shape[1:] != noise.shape[1:]: + latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0]) + + out['c_concat'] = conds.CONDNoiseShape(latent_image) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + if cross_attn.shape[-1] != 768: + cross_attn = self.cc_projection(cross_attn) + out['c_crossattn'] = conds.CONDCrossAttn(cross_attn) + return out + +class SD_X4Upscaler(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): + super().__init__(model_config, model_type, device=device) + self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350) + + def extra_conds(self, **kwargs): + out = {} + + image = kwargs.get("concat_image", None) + noise = kwargs.get("noise", None) + noise_augment = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + seed = kwargs["seed"] - 10 + + noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment) + + if image is None: + image = torch.zeros_like(noise)[:,:3] + + if image.shape[1:] != noise.shape[1:]: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + noise_level = torch.tensor([noise_level], device=device) + if noise_augment > 0: + image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed) + + image = utils.resize_to_batch_size(image, noise.shape[0]) + + out['c_concat'] = conds.CONDNoiseShape(image) + out['y'] = conds.CONDRegular(noise_level) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index d08941bb9..13ac3bdc4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype): unet_config = { "use_checkpoint": False, "image_size": 32, - "out_channels": 4, "use_spatial_transformer": True, "legacy": False } @@ -50,6 +49,12 @@ def detect_unet_config(state_dict, key_prefix, dtype): model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] in_channels = 
state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] + out_key = '{}out.2.weight'.format(key_prefix) + if out_key in state_dict: + out_channels = state_dict[out_key].shape[0] + else: + out_channels = 4 + num_res_blocks = [] channel_mult = [] attention_resolutions = [] @@ -122,6 +127,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): transformer_depth_middle = -1 unet_config["in_channels"] = in_channels + unet_config["out_channels"] = out_channels unet_config["model_channels"] = model_channels unet_config["num_res_blocks"] = num_res_blocks unet_config["transformer_depth"] = transformer_depth @@ -289,7 +295,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype): 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B] + Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega] for unet_config in supported_models: matches = True diff --git a/comfy/model_management.py b/comfy/model_management.py index 068e9d162..6d6800e86 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -2,6 +2,7 @@ import psutil from enum import Enum from .cli_args import args from . 
import utils + import torch import sys @@ -28,6 +29,10 @@ total_vram = 0 lowvram_available = True xpu_available = False +if args.deterministic: + print("Using deterministic algorithms for pytorch") + torch.use_deterministic_algorithms(True, warn_only=True) + directml_enabled = False if args.directml is not None: import torch_directml @@ -182,6 +187,9 @@ except: if is_intel_xpu(): VAE_DTYPE = torch.bfloat16 +if args.cpu_vae: + VAE_DTYPE = torch.float32 + if args.fp16_vae: VAE_DTYPE = torch.float16 elif args.bf16_vae: @@ -214,15 +222,8 @@ if args.force_fp16 or cpu_state == CPUState.MPS: FORCE_FP16 = True if lowvram_available: - try: - import accelerate - if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): - vram_state = set_vram_to - except Exception as e: - import traceback - print(traceback.format_exc()) - print("ERROR: LOW VRAM MODE NEEDS accelerate.") - lowvram_available = False + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to if cpu_state != CPUState.GPU: @@ -262,6 +263,14 @@ print("VAE dtype:", VAE_DTYPE) current_loaded_models = [] +def module_size(module): + module_mem = 0 + sd = module.state_dict() + for k in sd: + t = sd[k] + module_mem += t.nelement() * t.element_size() + return module_mem + class LoadedModel: def __init__(self, model): self.model = model @@ -294,8 +303,20 @@ class LoadedModel: if lowvram_model_memory > 0: print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024)) - device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) - accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device) + mem_counter = 0 + for m in self.real_model.modules(): + if hasattr(m, "comfy_cast_weights"): + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + module_mem = module_size(m) + if mem_counter + module_mem < lowvram_model_memory: + m.to(self.device) + mem_counter += module_mem + elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode + m.to(self.device) + mem_counter += module_size(m) + print("lowvram: loaded module regularly", m) + self.model_accelerated = True if is_intel_xpu() and not args.disable_ipex_optimize: @@ -305,7 +326,11 @@ class LoadedModel: def model_unload(self): if self.model_accelerated: - accelerate.hooks.remove_hook_from_submodules(self.real_model) + for m in self.real_model.modules(): + if hasattr(m, "prev_comfy_cast_weights"): + m.comfy_cast_weights = m.prev_comfy_cast_weights + del m.prev_comfy_cast_weights + self.model_accelerated = False self.model.unpatch_model(self.model.offload_device) @@ -398,14 +423,14 @@ def load_models_gpu(models, memory_required=0): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = loaded_model.model_memory_required(torch_dev) current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) + lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary vram_set_state = VRAMState.LOW_VRAM else: lowvram_model_memory = 0 if vram_set_state == VRAMState.NO_VRAM: - lowvram_model_memory = 256 * 1024 * 1024 + lowvram_model_memory = 64 * 1024 * 1024 cur_loaded_model = loaded_model.model_load(lowvram_model_memory) 
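For reference, the accelerate-free lowvram path introduced above condenses to the sketch below (illustrative function name; the 64 MiB floor, the (free - 1 GiB)/1.3 budget, and the module_size() accounting are taken from the hunks):

def module_size(module):
    # Same accounting as the new helper above: total bytes over the state dict.
    return sum(t.nelement() * t.element_size() for t in module.state_dict().values())

def load_within_budget(model, device, budget_bytes):
    used = 0
    for m in model.modules():
        if hasattr(m, "comfy_cast_weights"):
            # Every cast-capable module is flipped to cast-on-use; the old
            # flag is saved so model_unload() can restore it later.
            m.prev_comfy_cast_weights = m.comfy_cast_weights
            m.comfy_cast_weights = True
            size = module_size(m)
            if used + size < budget_bytes:
                m.to(device)   # still fits: keep this module's weights resident
                used += size
        elif hasattr(m, "weight"):
            m.to(device)       # no cast support: has to be loaded in full
            used += module_size(m)
    return used

Unlike the removed accelerate.dispatch_model() path, this never places weights on the meta device, which is why resolve_lowvram_weight() can become a plain pass-through in the later comfy/model_management.py hunk.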
current_loaded_models.insert(0, loaded_model) @@ -430,6 +455,13 @@ def dtype_size(dtype): dtype_size = 4 if dtype == torch.float16 or dtype == torch.bfloat16: dtype_size = 2 + elif dtype == torch.float32: + dtype_size = 4 + else: + try: + dtype_size = dtype.itemsize + except: #Old pytorch doesn't have .itemsize + pass return dtype_size def unet_offload_device(): @@ -459,10 +491,30 @@ def unet_inital_load_device(parameters, dtype): def unet_dtype(device=None, model_params=0): if args.bf16_unet: return torch.bfloat16 + if args.fp16_unet: + return torch.float16 + if args.fp8_e4m3fn_unet: + return torch.float8_e4m3fn + if args.fp8_e5m2_unet: + return torch.float8_e5m2 if should_use_fp16(device=device, model_params=model_params): return torch.float16 return torch.float32 +# None means no manual cast +def unet_manual_cast(weight_dtype, inference_device): + if weight_dtype == torch.float32: + return None + + fp16_supported = should_use_fp16(inference_device, prioritize_performance=False) + if fp16_supported and weight_dtype == torch.float16: + return None + + if fp16_supported: + return torch.float16 + else: + return torch.float32 + def text_encoder_offload_device(): if args.gpu_only: return get_torch_device() @@ -492,12 +544,23 @@ def text_encoder_dtype(device=None): elif args.fp32_text_enc: return torch.float32 + if is_device_cpu(device): + return torch.float16 + if should_use_fp16(device, prioritize_performance=False): return torch.float16 else: return torch.float32 +def intermediate_device(): + if args.gpu_only: + return get_torch_device() + else: + return torch.device("cpu") + def vae_device(): + if args.cpu_vae: + return torch.device("cpu") return get_torch_device() def vae_offload_device(): @@ -515,6 +578,22 @@ def get_autocast_device(dev): return dev.type return "cuda" +def supports_dtype(device, dtype): #TODO + if dtype == torch.float32: + return True + if is_device_cpu(device): + return False + if dtype == torch.float16: + return True + if dtype == torch.bfloat16: + return True + return False + +def device_supports_non_blocking(device): + if is_device_mps(device): + return False #pytorch bug? mps doesn't support non blocking + return True + def cast_to_device(tensor, device, dtype, copy=False): device_supports_cast = False if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: @@ -525,15 +604,17 @@ def cast_to_device(tensor, device, dtype, copy=False): elif is_intel_xpu(): device_supports_cast = True + non_blocking = device_supports_non_blocking(device) + if device_supports_cast: if copy: if tensor.device == device: - return tensor.to(dtype, copy=copy) - return tensor.to(device, copy=copy).to(dtype) + return tensor.to(dtype, copy=copy, non_blocking=non_blocking) + return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) else: - return tensor.to(device).to(dtype) + return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) else: - return tensor.to(dtype).to(device, copy=copy) + return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking) def xformers_enabled(): global directml_enabled @@ -687,11 +768,11 @@ def soft_empty_cache(force=False): torch.cuda.empty_cache() torch.cuda.ipc_collect() -def resolve_lowvram_weight(weight, model, key): - if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. - key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. 
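The fp8 additions above are storage-only: --fp8-e4m3fn-unet and --fp8-e5m2-unet change the dtype the weights are kept in, while unet_manual_cast() picks the dtype to compute in and the per-layer cast performs the conversion at forward time. A self-contained mirror of that decision, assuming a PyTorch build that exposes the float8 dtypes (fp16_supported stands in for the should_use_fp16() call):

import torch

def manual_cast_dtype(weight_dtype, fp16_supported):
    # Mirrors unet_manual_cast() above: None means "run natively, no casting".
    if weight_dtype == torch.float32:
        return None
    if fp16_supported and weight_dtype == torch.float16:
        return None
    # Everything else (e.g. fp8 storage) is upcast per layer at forward time.
    return torch.float16 if fp16_supported else torch.float32

assert manual_cast_dtype(torch.float8_e4m3fn, fp16_supported=True) == torch.float16
assert manual_cast_dtype(torch.float16, fp16_supported=False) == torch.float32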
- op = utils.get_attr(model, '.'.join(key_split[:-1])) - weight = op._hf_hook.weights_map[key_split[-1]] +def unload_all_models(): + free_memory(1e30, get_torch_device()) + + +def resolve_lowvram_weight(weight, model, key): #TODO: remove return weight #TODO: might be cleaner to put this somewhere else diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c8fa1b358..814537171 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -28,13 +28,9 @@ class ModelPatcher: if self.size > 0: return self.size model_sd = self.model.state_dict() - size = 0 - for k in model_sd: - t = model_sd[k] - size += t.nelement() * t.element_size() - self.size = size + self.size = model_management.module_size(self.model) self.model_keys = set(model_sd.keys()) - return size + return self.size def clone(self): n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) @@ -55,11 +51,18 @@ class ModelPatcher: def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) - def set_model_sampler_cfg_function(self, sampler_cfg_function): + def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way else: self.model_options["sampler_cfg_function"] = sampler_cfg_function + if disable_cfg1_optimization: + self.model_options["disable_cfg1_optimization"] = True + + def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): + self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] + if disable_cfg1_optimization: + self.model_options["disable_cfg1_optimization"] = True def set_model_unet_function_wrapper(self, unet_wrapper_function): self.model_options["model_function_wrapper"] = unet_wrapper_function @@ -70,13 +73,17 @@ class ModelPatcher: to["patches"] = {} to["patches"][name] = to["patches"].get(name, []) + [patch] - def set_model_patch_replace(self, patch, name, block_name, number): + def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): to = self.model_options["transformer_options"] if "patches_replace" not in to: to["patches_replace"] = {} if name not in to["patches_replace"]: to["patches_replace"][name] = {} - to["patches_replace"][name][(block_name, number)] = patch + if transformer_index is not None: + block = (block_name, number, transformer_index) + else: + block = (block_name, number) + to["patches_replace"][name][block] = patch def set_model_attn1_patch(self, patch): self.set_model_patch(patch, "attn1_patch") @@ -84,11 +91,11 @@ class ModelPatcher: def set_model_attn2_patch(self, patch): self.set_model_patch(patch, "attn2_patch") - def set_model_attn1_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn1", block_name, number) + def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index) - def set_model_attn2_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn2", block_name, number) + def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, 
"attn2", block_name, number, transformer_index) def set_model_attn1_output_patch(self, patch): self.set_model_patch(patch, "attn1_output_patch") @@ -167,40 +174,41 @@ class ModelPatcher: sd.pop(k) return sd - def patch_model(self, device_to=None): + def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: old = getattr(self.model, k) if k not in self.object_patches_backup: self.object_patches_backup[k] = old setattr(self.model, k, self.object_patches[k]) - model_sd = self.model_state_dict() - for key in self.patches: - if key not in model_sd: - print("could not patch. key doesn't exist in model:", key) - continue + if patch_weights: + model_sd = self.model_state_dict() + for key in self.patches: + if key not in model_sd: + print("could not patch. key doesn't exist in model:", key) + continue - weight = model_sd[key] + weight = model_sd[key] - inplace_update = self.weight_inplace_update + inplace_update = self.weight_inplace_update - if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + if key not in self.backup: + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + + if device_to is not None: + temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + if inplace_update: + utils.copy_to_param(self.model, key, out_weight) + else: + utils.set_attr(self.model, key, out_weight) + del temp_weight if device_to is not None: - temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - if inplace_update: - utils.copy_to_param(self.model, key, out_weight) - else: - utils.set_attr(self.model, key, out_weight) - del temp_weight - - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to + self.model.to(device_to) + self.current_device = device_to return self.model @@ -217,13 +225,19 @@ class ModelPatcher: v = (self.calculate_weight(v[1:], v[0].clone(), key), ) if len(v) == 1: + patch_type = "diff" + elif len(v) == 2: + patch_type = v[0] + v = v[1] + + if patch_type == "diff": w1 = v[0] if alpha != 0.0: if w1.shape != weight.shape: print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype) - elif len(v) == 4: #lora/locon + elif patch_type == "lora": #lora/locon mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32) mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32) if v[2] is not None: @@ -237,7 +251,7 @@ class ModelPatcher: weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) except Exception as e: print("ERROR", key, e) - elif len(v) == 8: #lokr + elif patch_type == "lokr": w1 = v[0] w2 = v[1] w1_a = v[3] @@ -276,7 +290,7 @@ class ModelPatcher: weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) except Exception as e: print("ERROR", key, e) - else: #loha + elif patch_type == "loha": w1a = v[0] w1b = v[1] if v[2] is not None: @@ -305,6 +319,18 @@ class ModelPatcher: weight += (alpha * m1 * 
m2).reshape(weight.shape).type(weight.dtype) except Exception as e: print("ERROR", key, e) + elif patch_type == "glora": + if v[4] is not None: + alpha *= v[4] / v[0].shape[0] + + a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) + a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) + b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) + b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) + + weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + else: + print("patch type not recognized", patch_type, key) return weight diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 69c8b1f01..cc8745c10 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -22,10 +22,17 @@ class V_PREDICTION(EPS): class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None): super().__init__() - beta_schedule = "linear" + if model_config is not None: - beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule) - self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + + beta_schedule = sampling_settings.get("beta_schedule", "linear") + linear_start = sampling_settings.get("linear_start", 0.00085) + linear_end = sampling_settings.get("linear_end", 0.012) + + self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3) self.sigma_data = 1.0 def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index f2665565b..068cd6c95 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -6,7 +6,7 @@ import hashlib import math import random -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageSequence from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch @@ -930,8 +930,8 @@ class GLIGENTextBoxApply: return (c, ) class EmptyLatentImage: - def __init__(self, device="cpu"): - self.device = device + def __init__(self): + self.device = comfy.model_management.intermediate_device() @classmethod def INPUT_TYPES(s): @@ -944,7 +944,7 @@ class EmptyLatentImage: CATEGORY = "latent" def generate(self, width, height, batch_size=1): - latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) return ({"samples":latent}, ) @@ -1395,17 +1395,30 @@ class LoadImage: FUNCTION = "load_image" def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) - i = Image.open(image_path) - i = ImageOps.exif_transpose(i) - image = i.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - if 'A' in i.getbands(): - mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. 
- torch.from_numpy(mask) + img = Image.open(image_path) + output_images = [] + output_masks = [] + for i in ImageSequence.Iterator(img): + i = ImageOps.exif_transpose(i) + image = i.convert("RGB") + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) + + if len(output_images) > 1: + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (image, mask.unsqueeze(0)) + output_image = output_images[0] + output_mask = output_masks[0] + + return (output_image, output_mask) @classmethod def IS_CHANGED(s, image): @@ -1463,13 +1476,10 @@ class LoadImageMask: return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, image, channel): + def VALIDATE_INPUTS(s, image): if not folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) - if channel not in s._color_channels: - return "Invalid color channel: {}".format(channel) - return True class ImageScale: diff --git a/comfy/ops.py b/comfy/ops.py index 0bfb698aa..f6f85de60 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,40 +1,115 @@ import torch from contextlib import contextmanager +import comfy.model_management -class Linear(torch.nn.Linear): - def reset_parameters(self): - return None +def cast_bias_weight(s, input): + bias = None + non_blocking = comfy.model_management.device_supports_non_blocking(input.device) + if s.bias is not None: + bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + return weight, bias -class Conv2d(torch.nn.Conv2d): - def reset_parameters(self): - return None -class Conv3d(torch.nn.Conv3d): - def reset_parameters(self): - return None +class disable_weight_init: + class Linear(torch.nn.Linear): + comfy_cast_weights = False + def reset_parameters(self): + return None -def conv_nd(dims, *args, **kwargs): - if dims == 2: - return Conv2d(*args, **kwargs) - elif dims == 3: - return Conv3d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.linear(input, weight, bias) -@contextmanager -def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way - old_torch_nn_linear = torch.nn.Linear - force_device = device - force_dtype = dtype - def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): - if force_device is not None: - device = force_device - if force_dtype is not None: - dtype = force_dtype - return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) - torch.nn.Linear = linear_with_dtype - try: - yield - finally: - torch.nn.Linear = old_torch_nn_linear + class Conv2d(torch.nn.Conv2d): + comfy_cast_weights = False + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + weight, bias = 
cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + class Conv3d(torch.nn.Conv3d): + comfy_cast_weights = False + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + class GroupNorm(torch.nn.GroupNorm): + comfy_cast_weights = False + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + + class LayerNorm(torch.nn.LayerNorm): + comfy_cast_weights = False + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + @classmethod + def conv_nd(s, dims, *args, **kwargs): + if dims == 2: + return s.Conv2d(*args, **kwargs) + elif dims == 3: + return s.Conv3d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +class manual_cast(disable_weight_init): + class Linear(disable_weight_init.Linear): + comfy_cast_weights = True + + class Conv2d(disable_weight_init.Conv2d): + comfy_cast_weights = True + + class Conv3d(disable_weight_init.Conv3d): + comfy_cast_weights = True + + class GroupNorm(disable_weight_init.GroupNorm): + comfy_cast_weights = True + + class LayerNorm(disable_weight_init.LayerNorm): + comfy_cast_weights = True diff --git a/comfy/package_data_path_helper.py b/comfy/package_data_path_helper.py new file mode 100644 index 000000000..d4dea80a4 --- /dev/null +++ b/comfy/package_data_path_helper.py @@ -0,0 +1,9 @@ +from importlib.resources import path +import os + + +def get_editable_resource_path(caller_file, *package_path): + filename = os.path.join(os.path.dirname(os.path.realpath(caller_file)), package_path[-1]) + if not os.path.exists(filename): + filename = path(*package_path) + return filename diff --git a/comfy/sample.py b/comfy/sample.py index 197f2eb30..565ea662d 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -47,7 +47,8 @@ def convert_cond(cond): temp = c[1].copy() model_conds = temp.get("model_conds", {}) if c[0] is not None: - model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) + model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove + temp["cross_attn"] = c[0] temp["model_conds"] = model_conds out.append(temp) return out @@ -98,10 +99,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) samples = 
sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.cpu() + samples = samples.to(model_management.intermediate_device()) cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) + cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): @@ -111,8 +112,8 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent sigmas = sigmas.to(model.load_device) samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.cpu() + samples = samples.to(model_management.intermediate_device()) cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) + cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples diff --git a/comfy/samplers.py b/comfy/samplers.py index 6ddadd893..eeac3fbd5 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,256 +1,264 @@ from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc import torch +import collections from . 
import model_management import math +def get_area_and_mult(conds, x_in, timestep_in): + area = (x_in.shape[2], x_in.shape[3], 0, 0) + strength = 1.0 + + if 'timestep_start' in conds: + timestep_start = conds['timestep_start'] + if timestep_in[0] > timestep_start: + return None + if 'timestep_end' in conds: + timestep_end = conds['timestep_end'] + if timestep_in[0] < timestep_end: + return None + if 'area' in conds: + area = conds['area'] + if 'strength' in conds: + strength = conds['strength'] + + input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + if 'mask' in conds: + # Scale the mask to the size of the input + # The mask should have been resized as we began the sampling process + mask_strength = 1.0 + if "mask_strength" in conds: + mask_strength = conds["mask_strength"] + mask = conds['mask'] + assert(mask.shape[1] == x_in.shape[2]) + assert(mask.shape[2] == x_in.shape[3]) + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength + mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) + else: + mask = torch.ones_like(input_x) + mult = mask * strength + + if 'mask' not in conds: + rr = 8 + if area[2] != 0: + for t in range(rr): + mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) + if (area[0] + area[2]) < x_in.shape[2]: + for t in range(rr): + mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) + if area[3] != 0: + for t in range(rr): + mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) + if (area[1] + area[3]) < x_in.shape[3]: + for t in range(rr): + mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) + + conditioning = {} + model_conds = conds["model_conds"] + for c in model_conds: + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + + control = conds.get('control', None) + + patches = None + if 'gligen' in conds: + gligen = conds['gligen'] + patches = {} + gligen_type = gligen[0] + gligen_model = gligen[1] + if gligen_type == "position": + gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) + else: + gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) + + patches['middle_patch'] = [gligen_patch] + + cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) + return cond_obj(input_x, mult, conditioning, area, control, patches) + +def cond_equal_size(c1, c2): + if c1 is c2: + return True + if c1.keys() != c2.keys(): + return False + for k in c1: + if not c1[k].can_concat(c2[k]): + return False + return True + +def can_concat_cond(c1, c2): + if c1.input_x.shape != c2.input_x.shape: + return False + + def objects_concatable(obj1, obj2): + if (obj1 is None) != (obj2 is None): + return False + if obj1 is not None: + if obj1 is not obj2: + return False + return True + + if not objects_concatable(c1.control, c2.control): + return False + + if not objects_concatable(c1.patches, c2.patches): + return False + + return cond_equal_size(c1.conditioning, c2.conditioning) + +def cond_cat(c_list): + c_crossattn = [] + c_concat = [] + c_adm = [] + crossattn_max_len = 0 + + temp = {} + for x in c_list: + for k in x: + cur = temp.get(k, []) + cur.append(x[k]) + temp[k] = cur + + out = {} + for k in temp: + conds = temp[k] + out[k] = conds[0].concat(conds[1:]) + + return out + +def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): + out_cond = torch.zeros_like(x_in) + out_count = torch.ones_like(x_in) * 1e-37 + + 
out_uncond = torch.zeros_like(x_in) + out_uncond_count = torch.ones_like(x_in) * 1e-37 + + COND = 0 + UNCOND = 1 + + to_run = [] + for x in cond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, COND)] + if uncond is not None: + for x in uncond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, UNCOND)] + + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + free_memory = model_management.get_free_memory(x_in.device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + if model.memory_required(input_shape) < free_memory: + to_batch = batch_amount + break + + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + area = [] + control = None + patches = None + for x in to_batch: + o = to_run.pop(x) + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x) + c = cond_cat(c) + timestep_ = torch.cat([timestep] * batch_chunks) + + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) + + transformer_options = {} + if 'transformer_options' in model_options: + transformer_options = model_options['transformer_options'].copy() + + if patches is not None: + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + else: + transformer_options["patches"] = patches + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["sigmas"] = timestep + + c['transformer_options'] = transformer_options + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + else: + output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + del input_x + + for o in range(batch_chunks): + if cond_or_uncond[o] == COND: + out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + else: + out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + del mult + + out_cond /= out_count + del out_count + out_uncond /= out_uncond_count + del out_uncond_count + return out_cond, out_uncond #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - def get_area_and_mult(conds, x_in, timestep_in): - area = (x_in.shape[2], x_in.shape[3], 0, 0) - strength = 1.0 - - if 'timestep_start' in conds: - timestep_start = conds['timestep_start'] - if timestep_in[0] > timestep_start: - return None - if 'timestep_end' in conds: - 
timestep_end = conds['timestep_end'] - if timestep_in[0] < timestep_end: - return None - if 'area' in conds: - area = conds['area'] - if 'strength' in conds: - strength = conds['strength'] - - input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if 'mask' in conds: - # Scale the mask to the size of the input - # The mask should have been resized as we began the sampling process - mask_strength = 1.0 - if "mask_strength" in conds: - mask_strength = conds["mask_strength"] - mask = conds['mask'] - assert(mask.shape[1] == x_in.shape[2]) - assert(mask.shape[2] == x_in.shape[3]) - mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) - else: - mask = torch.ones_like(input_x) - mult = mask * strength - - if 'mask' not in conds: - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - - conditionning = {} - model_conds = conds["model_conds"] - for c in model_conds: - conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) - - control = None - if 'control' in conds: - control = conds['control'] - - patches = None - if 'gligen' in conds: - gligen = conds['gligen'] - patches = {} - gligen_type = gligen[0] - gligen_model = gligen[1] - if gligen_type == "position": - gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) - else: - gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) - - patches['middle_patch'] = [gligen_patch] - - return (input_x, mult, conditionning, area, control, patches) - - def cond_equal_size(c1, c2): - if c1 is c2: - return True - if c1.keys() != c2.keys(): - return False - for k in c1: - if not c1[k].can_concat(c2[k]): - return False - return True - - def can_concat_cond(c1, c2): - if c1[0].shape != c2[0].shape: - return False - - #control - if (c1[4] is None) != (c2[4] is None): - return False - if c1[4] is not None: - if c1[4] is not c2[4]: - return False - - #patches - if (c1[5] is None) != (c2[5] is None): - return False - if (c1[5] is not None): - if c1[5] is not c2[5]: - return False - - return cond_equal_size(c1[2], c2[2]) - - def cond_cat(c_list): - c_crossattn = [] - c_concat = [] - c_adm = [] - crossattn_max_len = 0 - - temp = {} - for x in c_list: - for k in x: - cur = temp.get(k, []) - cur.append(x[k]) - temp[k] = cur - - out = {} - for k in temp: - conds = temp[k] - out[k] = conds[0].concat(conds[1:]) - - return out - - def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): - out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in) * 1e-37 - - out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in) * 1e-37 - - COND = 0 - UNCOND = 1 - - to_run = [] - for x in cond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - - to_run += [(p, COND)] - if uncond is not None: - for x in uncond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - - to_run += [(p, UNCOND)] - - while len(to_run) > 0: - first = to_run[0] - first_shape = first[0][0].shape - 
to_batch_temp = [] - for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): - to_batch_temp += [x] - - to_batch_temp.reverse() - to_batch = to_batch_temp[:1] - - free_memory = model_management.get_free_memory(x_in.device) - for i in range(1, len(to_batch_temp) + 1): - batch_amount = to_batch_temp[:len(to_batch_temp)//i] - input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) < free_memory: - to_batch = batch_amount - break - - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - area = [] - control = None - patches = None - for x in to_batch: - o = to_run.pop(x) - p = o[0] - input_x += [p[0]] - mult += [p[1]] - c += [p[2]] - area += [p[3]] - cond_or_uncond += [o[1]] - control = p[4] - patches = p[5] - - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) - c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) - - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - - transformer_options = {} - if 'transformer_options' in model_options: - transformer_options = model_options['transformer_options'].copy() - - if patches is not None: - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - else: - transformer_options["patches"] = patches - - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["sigmas"] = timestep - - c['transformer_options'] = transformer_options - - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) - else: - output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - del input_x - - for o in range(batch_chunks): - if cond_or_uncond[o] == COND: - out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - else: - out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - del mult - - out_cond /= out_count - del out_count - out_uncond /= out_uncond_count - del out_uncond_count - return out_cond, out_uncond - - - if math.isclose(cond_scale, 1.0): - uncond = None - - cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) - if "sampler_cfg_function" in model_options: - args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} - return x - model_options["sampler_cfg_function"](args) + if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: + uncond_ = None else: - return uncond + (cond - uncond) * cond_scale + uncond_ = uncond + + cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) + if "sampler_cfg_function" in model_options: + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": 
model_options} + cfg_result = x - model_options["sampler_cfg_function"](args) + else: + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) + + return cfg_result class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): @@ -272,10 +280,7 @@ class KSamplerX0Inpaint(torch.nn.Module): x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) if denoise_mask is not None: - out *= denoise_mask - - if denoise_mask is not None: - out += self.latent_image * latent_mask + out = out * denoise_mask + self.latent_image * latent_mask return out def simple_scheduler(model, steps): @@ -590,6 +595,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model calculate_start_end_timesteps(model, negative) calculate_start_end_timesteps(model, positive) + if latent_image is not None: + latent_image = model.process_latent_in(latent_image) + + if hasattr(model, 'extra_conds'): + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + #make sure each cond area has an opposite one with the same area for c in positive: create_cond_with_same_area_if_none(negative, c) @@ -601,13 +613,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) - if latent_image is not None: - latent_image = model.process_latent_in(latent_image) - - if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) @@ -630,7 +635,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps): elif scheduler_name == "sgm_uniform": sigmas = normal_scheduler(model, steps, sgm=True) else: - print("error invalid scheduler", self.scheduler) + print("error invalid scheduler", scheduler_name) return sigmas def sampler_object(name): diff --git a/comfy/sd.py b/comfy/sd.py index cadce719e..c51172a46 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -148,12 +148,14 @@ class CLIP: return self.patcher.get_key_patches() class VAE: - def __init__(self, sd=None, device=None, config=None): + def __init__(self, sd=None, device=None, config=None, dtype=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in 
sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) + self.downscale_ratio = 8 + self.latent_channels = 4 if config is None: if "decoder.mid.block_1.mix_factor" in sd: @@ -169,6 +171,11 @@ class VAE: else: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + + if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE + ddconfig['ch_mult'] = [1, 2, 4] + self.downscale_ratio = 4 + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: self.first_stage_model = AutoencoderKL(**(config['params'])) @@ -185,8 +192,11 @@ class VAE: device = model_management.vae_device() self.device = device offload_device = model_management.vae_offload_device() - self.vae_dtype = model_management.vae_dtype() + if dtype is None: + dtype = model_management.vae_dtype() + self.vae_dtype = dtype self.first_stage_model.to(self.vae_dtype) + self.output_device = model_management.intermediate_device() self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) @@ -198,9 +208,9 @@ class VAE: decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() output = torch.clamp(( - (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + - utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + - utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) + (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + + utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + + utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar)) / 3.0) / 2.0, min=0.0, max=1.0) return output @@ -211,9 +221,9 @@ class VAE: pbar = utils.ProgressBar(steps) encode_fn = lambda a: self.first_stage_model.encode((2. 
* a - 1.).to(self.vae_dtype).to(self.device)).float() - samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples /= 3.0 return samples @@ -225,15 +235,15 @@ class VAE: batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") + pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) + pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0) except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) - pixel_samples = pixel_samples.cpu().movedim(1,-1) + pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): @@ -249,10 +259,10 @@ class VAE: free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") + samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device) for x in range(0, pixel_samples.shape[0], batch_number): pixels_in = (2. 
* pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) - samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float() + samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") @@ -429,11 +439,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o parameters = utils.calculate_parameters(sd, "model.diffusion_model.") unet_dtype = model_management.unet_dtype(model_params=parameters) + load_device = model_management.get_torch_device() + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) class WeightsLoader(torch.nn.Module): pass model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) + model_config.set_manual_cast(manual_cast_dtype) + if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) @@ -466,7 +480,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o print("left over keys:", left_over) if output_model: - _model_patcher = model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) if inital_load_device != torch.device("cpu"): print("loaded straight to GPU") model_management.load_model_gpu(model_patcher) @@ -477,6 +491,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet_state_dict(sd): #load unet in diffusers format parameters = utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) + load_device = model_management.get_torch_device() + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) + if "input_blocks.0.0.weight" in sd: #ldm model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) if model_config is None: @@ -497,13 +514,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format else: print(diffusers_keys[k], k) offload_device = model_management.unet_offload_device() + model_config.set_manual_cast(manual_cast_dtype) model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() if len(left_over) > 0: print("left over keys in unet:", left_over) - return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) + return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) def load_unet(unet_path): sd = utils.load_torch_file(unet_path) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 0ea9611dd..6722eb83f 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -1,6 +1,6 @@ import os -from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils +from transformers import CLIPTokenizer from . import ops import torch import traceback @@ -8,6 +8,8 @@ import zipfile from . import model_management from pkg_resources import resource_filename import contextlib +from . 
import clip_model +import json def gen_empty_tokens(special_tokens, length): start_token = special_tokens.get("start", None) @@ -38,7 +40,7 @@ class ClipTokenWeightEncoder: out, pooled = self.encode(to_encode) if pooled is not None: - first_pooled = pooled[0:1].cpu() + first_pooled = pooled[0:1].to(model_management.intermediate_device()) else: first_pooled = pooled @@ -55,8 +57,8 @@ class ClipTokenWeightEncoder: output.append(z) if (len(output) == 0): - return out[-1:].cpu(), first_pooled - return torch.cat(output, dim=-2).cpu(), first_pooled + return out[-1:].to(model_management.intermediate_device()), first_pooled + return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" @@ -66,33 +68,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): "hidden" ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, - freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None, - special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig, - model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32 + freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=clip_model.CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS - self.num_layers = 12 - if textmodel_path is not None: - self.transformer = model_class.from_pretrained(textmodel_path) - else: - if textmodel_json_config is None: - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") - if not os.path.exists(textmodel_json_config): - textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json') - config = config_class.from_json_file(textmodel_json_config) - self.num_layers = config.num_hidden_layers - with ops.use_comfy_ops(device, dtype): - with modeling_utils.no_init_weights(): - self.transformer = model_class(config) - self.inner_name = inner_name - if dtype is not None: - self.transformer.to(dtype) - inner_model = getattr(self.transformer, self.inner_name) - if hasattr(inner_model, "embeddings"): - inner_model.embeddings.to(torch.float32) - else: - self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) + if textmodel_json_config is None: + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") + if not os.path.exists(textmodel_json_config): + textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json') + + with open(textmodel_json_config) as f: + config = json.load(f) + + self.transformer = model_class(config, dtype, device, ops.manual_cast) + self.num_layers = self.transformer.num_layers self.max_length = max_length if freeze: @@ -107,7 +97,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_norm_hidden_state = layer_norm_hidden_state if layer == "hidden": assert layer_idx is not None - assert abs(layer_idx) <= self.num_layers + assert abs(layer_idx) < self.num_layers self.clip_layer(layer_idx) self.layer_default = (self.layer, self.layer_idx) @@ -118,7 +108,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): param.requires_grad = False def clip_layer(self, 
layer_idx): - if abs(layer_idx) >= self.num_layers: + if abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" @@ -173,41 +163,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = torch.LongTensor(tokens).to(device) - if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32: - precision_scope = torch.autocast + attention_mask = None + if self.enable_attention_masks: + attention_mask = torch.zeros_like(tokens) + max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 + for x in range(attention_mask.shape[0]): + for y in range(attention_mask.shape[1]): + attention_mask[x, y] = 1 + if tokens[x, y] == max_token: + break + + outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + self.transformer.set_input_embeddings(backup_embeds) + + if self.layer == "last": + z = outputs[0] else: - precision_scope = lambda a, dtype: contextlib.nullcontext(a) + z = outputs[1] - with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32): - attention_mask = None - if self.enable_attention_masks: - attention_mask = torch.zeros_like(tokens) - max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 - for x in range(attention_mask.shape[0]): - for y in range(attention_mask.shape[1]): - attention_mask[x, y] = 1 - if tokens[x, y] == max_token: - break + if outputs[2] is not None: + pooled_output = outputs[2].float() + else: + pooled_output = None - outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden") - self.transformer.set_input_embeddings(backup_embeds) - - if self.layer == "last": - z = outputs.last_hidden_state - elif self.layer == "pooled": - z = outputs.pooler_output[:, None, :] - else: - z = outputs.hidden_states[self.layer_idx] - if self.layer_norm_hidden_state: - z = getattr(self.transformer, self.inner_name).final_layer_norm(z) - - if hasattr(outputs, "pooler_output"): - pooled_output = outputs.pooler_output.float() - else: - pooled_output = None - - if self.text_projection is not None and pooled_output is not None: - pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() + if self.text_projection is not None and pooled_output is not None: + pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() return z.float(), pooled_output def encode(self, tokens): diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index a33ce4210..db22e354d 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -4,15 +4,15 @@ from . 
import sd1_clip import os class SD2ClipHModel(sd1_clip.SDClipModel): - def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): + def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): if layer == "penultimate": layer="hidden" - layer_idx=23 + layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") if not os.path.exists(textmodel_json_config): textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json') - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 0a3d0530f..2d7a9e4e9 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -3,13 +3,13 @@ import torch import os class SDXLClipG(sd1_clip.SDClipModel): - def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): if layer == "penultimate": layer="hidden" layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) def load_sd(self, sd): @@ -37,7 +37,7 @@ class SDXLTokenizer: class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_g = SDXLClipG(device=device, dtype=dtype) def clip_layer(self, layer_idx): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 455323b96..1d442d4dd 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -217,6 +217,16 @@ class SSD1B(SDXL): "use_temporal_attention": False, } +class Segmind_Vega(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 0, 1, 1, 2, 2], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + class SVD_img2vid(supported_models_base.BASE): unet_config = { "model_channels": 320, @@ -242,5 +252,59 @@ class SVD_img2vid(supported_models_base.BASE): def clip_target(self): return None -models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] +class Stable_Zero123(supported_models_base.BASE): + unet_config = { + "context_dim": 768, 
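+        # hedged note: these unet_config keys act as a detection fingerprint;
+        # BASE.matches() (see the supported_models_base hunk below) compares each
+        # entry against the UNet config inferred from the checkpoint, so the
+        # 8 input channels and absent ADM input are what single out Zero123 here.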
+ "model_channels": 320, + "use_linear_in_transformer": False, + "adm_in_channels": None, + "use_temporal_attention": False, + "in_channels": 8, + } + + unet_extra_config = { + "num_heads": 8, + "num_head_channels": -1, + } + + clip_vision_prefix = "cond_stage_model.model.visual." + + latent_format = latent_formats.SD15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"]) + return out + + def clip_target(self): + return None + +class SD_X4Upscaler(SD20): + unet_config = { + "context_dim": 1024, + "model_channels": 256, + 'in_channels': 7, + "use_linear_in_transformer": True, + "adm_in_channels": None, + "use_temporal_attention": False, + } + + unet_extra_config = { + "disable_self_attentions": [True, True, True, False], + "num_classes": 1000, + "num_heads": 8, + "num_head_channels": -1, + } + + latent_format = latent_formats.SD_X4 + + sampling_settings = { + "linear_start": 0.0001, + "linear_end": 0.02, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SD_X4Upscaler(self, device=device) + return out + +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler] models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 3412cfea0..49087d23e 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -22,6 +22,8 @@ class BASE: sampling_settings = {} latent_format = latent_formats.LatentFormat + manual_cast_dtype = None + @classmethod def matches(s, unet_config): for k in s.unet_config: @@ -71,3 +73,5 @@ class BASE: replace_prefix = {"": "first_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def set_manual_cast(self, manual_cast_dtype): + self.manual_cast_dtype = manual_cast_dtype diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index a91da2f91..856f4f2cd 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -7,9 +7,10 @@ import torch import torch.nn as nn from .. import utils +from .. 
import ops def conv(n_in, n_out, **kwargs): - return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + return ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs) class Clamp(nn.Module): def forward(self, x): @@ -19,7 +20,7 @@ class Block(nn.Module): def __init__(self, n_in, n_out): super().__init__() self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) - self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.skip = ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.fuse = nn.ReLU() def forward(self, x): return self.fuse(self.conv(x) + self.skip(x)) diff --git a/comfy/utils.py b/comfy/utils.py index 0aa655264..8b71ed387 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -378,7 +378,7 @@ def lanczos(samples, width, height): images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] result = torch.stack(images) - return result + return result.to(samples.device, samples.dtype) def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": @@ -407,17 +407,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) @torch.inference_mode() -def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): - output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu") +def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): + output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device) for b in range(samples.shape[0]): s = samples[b:b+1] - out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") - out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") + out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) + out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) for y in range(0, s.shape[2], tile_y - overlap): for x in range(0, s.shape[3], tile_x - overlap): s_in = s[:,:,y:y+tile_y,x:x+tile_x] - ps = function(s_in).cpu() + ps = function(s_in).to(output_device) mask = torch.ones_like(ps) feather = round(overlap * upscale_amount) for t in range(feather): diff --git a/comfy_extras/nodes/nodes_canny.py b/comfy_extras/nodes/nodes_canny.py index 94d453f2c..730dded5f 100644 --- a/comfy_extras/nodes/nodes_canny.py +++ b/comfy_extras/nodes/nodes_canny.py @@ -291,7 +291,7 @@ class Canny: def detect_edge(self, image, low_threshold, high_threshold): output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) - img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1) + img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) return 
(img_out,) NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py index 04983f5be..baf72ecc5 100644 --- a/comfy_extras/nodes/nodes_custom_sampler.py +++ b/comfy_extras/nodes/nodes_custom_sampler.py @@ -1,9 +1,9 @@ -import comfy.samplers -import comfy.sample +from comfy import samplers +from comfy import sample from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.cmd import latent_preview import torch -import comfy.utils +from comfy import utils class BasicScheduler: @@ -11,8 +11,9 @@ class BasicScheduler: def INPUT_TYPES(s): return {"required": {"model": ("MODEL",), - "scheduler": (comfy.samplers.SCHEDULER_NAMES, ), + "scheduler": (samplers.SCHEDULER_NAMES, ), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), } } RETURN_TYPES = ("SIGMAS",) @@ -20,8 +21,15 @@ class BasicScheduler: FUNCTION = "get_sigmas" - def get_sigmas(self, model, scheduler, steps): - sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu() + def get_sigmas(self, model, scheduler, steps, denoise): + total_steps = steps + if denoise < 1.0: + total_steps = int(steps/denoise) + + inner_model = model.patch_model(patch_weights=False) + sigmas = samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu() + model.unpatch_model() + sigmas = sigmas[-(steps + 1):] return (sigmas, ) @@ -87,6 +95,7 @@ class SDTurboScheduler: return {"required": {"model": ("MODEL",), "steps": ("INT", {"default": 1, "min": 1, "max": 10}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), } } RETURN_TYPES = ("SIGMAS",) @@ -94,9 +103,12 @@ class SDTurboScheduler: FUNCTION = "get_sigmas" - def get_sigmas(self, model, steps): - timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps] - sigmas = model.model.model_sampling.sigma(timesteps) + def get_sigmas(self, model, steps, denoise): + start_step = 10 - int(10 * denoise) + timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] + inner_model = model.patch_model(patch_weights=False) + sigmas = inner_model.model_sampling.sigma(timesteps) + model.unpatch_model() sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) return (sigmas, ) @@ -159,7 +171,7 @@ class KSamplerSelect: @classmethod def INPUT_TYPES(s): return {"required": - {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ), + {"sampler_name": (samplers.SAMPLER_NAMES, ), } } RETURN_TYPES = ("SAMPLER",) @@ -168,7 +180,7 @@ class KSamplerSelect: FUNCTION = "get_sampler" def get_sampler(self, sampler_name): - sampler = comfy.samplers.sampler_object(sampler_name) + sampler = samplers.sampler_object(sampler_name) return (sampler, ) class SamplerDPMPP_2M_SDE: @@ -191,7 +203,7 @@ class SamplerDPMPP_2M_SDE: sampler_name = "dpmpp_2m_sde" else: sampler_name = "dpmpp_2m_sde_gpu" - sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) + sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) return (sampler, ) @@ -215,7 +227,7 @@ class SamplerDPMPP_SDE: sampler_name = "dpmpp_sde" else: sampler_name = "dpmpp_sde_gpu" - sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) + sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) return (sampler, ) class SamplerCustom: @@ -248,7 +260,7 @@ class 
SamplerCustom: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: batch_inds = latent["batch_index"] if "batch_index" in latent else None - noise = comfy.sample.prepare_noise(latent_image, noise_seed, batch_inds) + noise = sample.prepare_noise(latent_image, noise_seed, batch_inds) noise_mask = None if "noise_mask" in latent: @@ -257,8 +269,8 @@ class SamplerCustom: x0_output = {} callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) - disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED - samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) + disable_pbar = not utils.PROGRESS_BAR_ENABLED + samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) out = latent.copy() out["samples"] = samples diff --git a/comfy_extras/nodes/nodes_hypertile.py b/comfy_extras/nodes/nodes_hypertile.py index 0d7d4c954..e7446b2e5 100644 --- a/comfy_extras/nodes/nodes_hypertile.py +++ b/comfy_extras/nodes/nodes_hypertile.py @@ -2,9 +2,10 @@ import math from einops import rearrange -import random +# Use torch rng for consistency across generations +from torch import randint -def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int: +def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: min_value = min(min_value, value) # All big divisors of value (inclusive) @@ -12,8 +13,10 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter ns = [value // i for i in divisors[:max_options]] # has at least 1 element - random.seed(counter) - idx = random.randint(0, len(ns) - 1) + if len(ns) - 1 > 0: + idx = randint(low=0, high=len(ns) - 1, size=(1,)).item() + else: + idx = 0 return ns[idx] @@ -42,7 +45,6 @@ class HyperTile: latent_tile_size = max(32, tile_size) // 8 self.temp = None - self.counter = 1 def hypertile_in(q, k, v, extra_options): if q.shape[-1] in apply_to: @@ -53,10 +55,8 @@ class HyperTile: h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 - nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter) - self.counter += 1 - nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter) - self.counter += 1 + nh = random_divisor(h, latent_tile_size * factor, swap_size) + nw = random_divisor(w, latent_tile_size * factor, swap_size) if nh * nw > 1: q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) diff --git a/comfy_extras/nodes/nodes_images.py b/comfy_extras/nodes/nodes_images.py index d82e0e250..b6202f181 100644 --- a/comfy_extras/nodes/nodes_images.py +++ b/comfy_extras/nodes/nodes_images.py @@ -73,7 +73,7 @@ class SaveAnimatedWEBP: OUTPUT_NODE = True - CATEGORY = "_for_testing" + CATEGORY = "image/animation" def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): method = self.methods.get(method) @@ -135,7 +135,7 @@ class SaveAnimatedPNG: OUTPUT_NODE = True - CATEGORY = "_for_testing" + CATEGORY = "image/animation" def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): 
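        # hedged note: prefix_append is empty for plain save nodes; preview-style nodes elsewhere in the codebase use it to tack a temporary suffix onto filename_prefix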
filename_prefix += self.prefix_append diff --git a/comfy_extras/nodes/nodes_latent.py b/comfy_extras/nodes/nodes_latent.py index cedf39d63..2eefc4c55 100644 --- a/comfy_extras/nodes/nodes_latent.py +++ b/comfy_extras/nodes/nodes_latent.py @@ -3,9 +3,7 @@ import torch def reshape_latent_to(target_shape, latent): if latent.shape[1:] != target_shape[1:]: - latent.movedim(1, -1) latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") - latent.movedim(-1, 1) return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) @@ -102,9 +100,32 @@ class LatentInterpolate: samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) return (samples_out,) +class LatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "batch" + + CATEGORY = "latent/batch" + + def batch(self, samples1, samples2): + samples_out = samples1.copy() + s1 = samples1["samples"] + s2 = samples2["samples"] + + if s1.shape[1:] != s2.shape[1:]: + s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center") + s = torch.cat((s1, s2), dim=0) + samples_out["samples"] = s + samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, + "LatentBatch": LatentBatch, } diff --git a/comfy_extras/nodes/nodes_mask.py b/comfy_extras/nodes/nodes_mask.py index ae0f9d3b5..1ed455f04 100644 --- a/comfy_extras/nodes/nodes_mask.py +++ b/comfy_extras/nodes/nodes_mask.py @@ -7,6 +7,7 @@ from comfy.nodes.common import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): + source = source.to(destination.device) if resize_source: source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") @@ -21,7 +22,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou if mask is None: mask = torch.ones_like(source) else: - mask = mask.clone() + mask = mask.to(destination.device, copy=True) mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) diff --git a/comfy_extras/nodes/nodes_model_advanced.py b/comfy_extras/nodes/nodes_model_advanced.py index 6b95831bc..bbece153a 100644 --- a/comfy_extras/nodes/nodes_model_advanced.py +++ b/comfy_extras/nodes/nodes_model_advanced.py @@ -16,41 +16,19 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class ModelSamplingDiscreteDistilled(torch.nn.Module): +class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 - def __init__(self): - super().__init__() - self.sigma_data = 1.0 - timesteps = 1000 - beta_start = 0.00085 - beta_end = 0.012 + def __init__(self, model_config=None): + super().__init__(model_config) - betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) + self.skip_steps = self.num_timesteps // self.original_timesteps - self.skip_steps = timesteps // self.original_timesteps - - - 
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) for x in range(self.original_timesteps): - alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps] - sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 - self.set_sigmas(sigmas) - - def set_sigmas(self, sigmas): - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) - - @property - def sigma_min(self): - return self.sigmas[0] - - @property - def sigma_max(self): - return self.sigmas[-1] + self.set_sigmas(sigmas_valid) def timestep(self, sigma): log_sigma = sigma.log() @@ -65,14 +43,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module): log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] return log_sigma.exp().to(timestep.device) - def percent_to_sigma(self, percent): - if percent <= 0.0: - return 999999999.9 - if percent >= 1.0: - return 0.0 - percent = 1.0 - percent - return self.sigma(torch.tensor(percent * 999.0)).item() - def rescale_zero_terminal_snr_sigmas(sigmas): alphas_cumprod = 1 / ((sigmas * sigmas) + 1) @@ -121,7 +91,7 @@ class ModelSamplingDiscrete: class ModelSamplingAdvanced(sampling_base, sampling_type): pass - model_sampling = ModelSamplingAdvanced() + model_sampling = ModelSamplingAdvanced(model.model.model_config) if zsnr: model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) @@ -153,7 +123,7 @@ class ModelSamplingContinuousEDM: class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): pass - model_sampling = ModelSamplingAdvanced() + model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling.set_sigma_range(sigma_min, sigma_max) m.add_object_patch("model_sampling", model_sampling) return (m, ) diff --git a/comfy_extras/nodes/nodes_perpneg.py b/comfy_extras/nodes/nodes_perpneg.py new file mode 100644 index 000000000..875937530 --- /dev/null +++ b/comfy_extras/nodes/nodes_perpneg.py @@ -0,0 +1,53 @@ +import torch +from comfy import sample +from comfy import samplers + + +class PerpNeg: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "empty_conditioning": ("CONDITIONING", ), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, empty_conditioning, neg_scale): + m = model.clone() + nocond = sample.convert_cond(empty_conditioning) + + def cfg_function(args): + model = args["model"] + noise_pred_pos = args["cond_denoised"] + noise_pred_neg = args["uncond_denoised"] + cond_scale = args["cond_scale"] + x = args["input"] + sigma = args["sigma"] + model_options = args["model_options"] + nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative") + + (noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options) + + pos = noise_pred_pos - noise_pred_nocond + neg = noise_pred_neg - noise_pred_nocond + perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg + perp_neg = perp * neg_scale + cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) + cfg_result = x - cfg_result + return cfg_result + + 
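+        # Sketch of the math above, for orientation: perp is the projection of the
+        # positive direction onto the negative one,
+        #   perp = (pos . neg / ||neg||**2) * neg
+        # so pos - neg_scale * perp damps only the component of the positive
+        # guidance that is parallel to the negative prompt. The trailing
+        # cfg_result = x - cfg_result matches the sampler_cfg_function convention
+        # in the samplers hunk above, where the sampler computes x - hook(args)
+        # to recover the denoised result.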
m.set_model_sampler_cfg_function(cfg_function) + + return (m, ) + + +NODE_CLASS_MAPPINGS = { + "PerpNeg": PerpNeg, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PerpNeg": "Perp-Neg", +} diff --git a/comfy_extras/nodes/nodes_post_processing.py b/comfy_extras/nodes/nodes_post_processing.py index 9213eb546..334214ad1 100644 --- a/comfy_extras/nodes/nodes_post_processing.py +++ b/comfy_extras/nodes/nodes_post_processing.py @@ -226,7 +226,7 @@ class Sharpen: batch_size, height, width, channels = image.shape kernel_size = sharpen_radius * 2 + 1 - kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10) + kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10) center = kernel_size // 2 kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) diff --git a/comfy_extras/nodes/nodes_rebatch.py b/comfy_extras/nodes/nodes_rebatch.py index 88a4ebe29..3010fbd4b 100644 --- a/comfy_extras/nodes/nodes_rebatch.py +++ b/comfy_extras/nodes/nodes_rebatch.py @@ -99,10 +99,40 @@ class LatentRebatch: return (output_list,) +class ImageRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "images": ("IMAGE",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }} + RETURN_TYPES = ("IMAGE",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "image/batch" + + def rebatch(self, images, batch_size): + batch_size = batch_size[0] + + output_list = [] + all_images = [] + for img in images: + for i in range(img.shape[0]): + all_images.append(img[i:i+1]) + + for i in range(0, len(all_images), batch_size): + output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) + + return (output_list,) + NODE_CLASS_MAPPINGS = { "RebatchLatents": LatentRebatch, + "RebatchImages": ImageRebatch, } NODE_DISPLAY_NAME_MAPPINGS = { "RebatchLatents": "Rebatch Latents", -} \ No newline at end of file + "RebatchImages": "Rebatch Images", +} diff --git a/comfy_extras/nodes/nodes_sag.py b/comfy_extras/nodes/nodes_sag.py new file mode 100644 index 000000000..66606e328 --- /dev/null +++ b/comfy_extras/nodes/nodes_sag.py @@ -0,0 +1,168 @@ +import torch +from torch import einsum +import torch.nn.functional as F +import math + +from einops import rearrange, repeat +import os +from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION +from comfy import samplers + +# from comfy/ldm/modules/attention.py +# but modified to return attention scores as well as output +def attention_basic_with_sim(q, k, v, heads, mask=None): + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 + + h = heads + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * scale + + del q, k + + if mask is not None: + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) + out = ( + out.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) + return (out, sim) + +def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): + # reshape and GAP the attention map + _, hw1, hw2 = attn.shape + b, _, lh, lw = x0.shape + attn = attn.reshape(b, -1, hw1, hw2) + # Global Average Pool + mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold + ratio = math.ceil(math.sqrt(lh * lw / hw1)) + mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] + + # Reshape + mask = ( + mask.reshape(b, *mid_shape) + .unsqueeze(1) + .type(attn.dtype) + ) + # Upsample + mask = F.interpolate(mask, (lh, lw)) + + blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) + blurred = blurred * mask + x0 * (1 - mask) + return blurred + +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + return img + +class SelfAttentionGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}), + "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, scale, blur_sigma): + m = model.clone() + + attn_scores = None + + # TODO: make this work properly with chunked batches + # currently, we can only save the attn from one UNet call + def attn_and_record(q, k, v, extra_options): + nonlocal attn_scores + # if uncond, save the attention scores + heads = extra_options["n_heads"] + cond_or_uncond = extra_options["cond_or_uncond"] + b = q.shape[0] // len(cond_or_uncond) + if 1 in cond_or_uncond: + uncond_index = cond_or_uncond.index(1) + # do the entire attention operation, but save the attention scores to attn_scores + (out, sim) = attention_basic_with_sim(q, k, v, heads=heads) + # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... 
cn] + n_slices = heads * b + attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] + return out + else: + return optimized_attention(q, k, v, heads=heads) + + def post_cfg_function(args): + nonlocal attn_scores + uncond_attn = attn_scores + + sag_scale = scale + sag_sigma = blur_sigma + sag_threshold = 1.0 + model = args["model"] + uncond_pred = args["uncond_denoised"] + uncond = args["uncond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + model_options = args["model_options"] + x = args["input"] + + # create the adversarially blurred image + degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) + degraded_noised = degraded + x - uncond_pred + # call into the UNet + (sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) + return cfg_result + (degraded - sag) * sag_scale + + m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True) + + # from diffusers: + # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch + m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) + + return (m, ) + +NODE_CLASS_MAPPINGS = { + "SelfAttentionGuidance": SelfAttentionGuidance, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SelfAttentionGuidance": "Self-Attention Guidance", +} diff --git a/comfy_extras/nodes/nodes_sdupscale.py b/comfy_extras/nodes/nodes_sdupscale.py new file mode 100644 index 000000000..4502fcd84 --- /dev/null +++ b/comfy_extras/nodes/nodes_sdupscale.py @@ -0,0 +1,46 @@ +import torch +from comfy import utils + +class SD_4XUpscale_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "images": ("IMAGE",), + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/upscale_diffusion" + + def encode(self, images, positive, negative, scale_ratio, noise_augmentation): + width = max(1, round(images.shape[-2] * scale_ratio)) + height = max(1, round(images.shape[-3] * scale_ratio)) + + pixels = utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center") + + out_cp = [] + out_cn = [] + + for t in positive: + n = [t[0], t[1].copy()] + n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation + out_cp.append(n) + + for t in negative: + n = [t[0], t[1].copy()] + n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation + out_cn.append(n) + + latent = torch.zeros([images.shape[0], 4, height // 4, width // 4]) + return (out_cp, out_cn, {"samples":latent}) + +NODE_CLASS_MAPPINGS = { + "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning, +} diff --git a/comfy_extras/nodes/nodes_stable3d.py b/comfy_extras/nodes/nodes_stable3d.py new file mode 100644 index 000000000..7aa6ec858 --- /dev/null +++ b/comfy_extras/nodes/nodes_stable3d.py @@ -0,0 +1,61 @@ +import torch +import comfy.utils + +from comfy.nodes.common import MAX_RESOLUTION +from comfy import utils + + +def camera_embeddings(elevation, azimuth): + elevation = torch.as_tensor([elevation]) + azimuth = torch.as_tensor([azimuth]) + embeddings = torch.stack( + [ + torch.deg2rad( + (90 - elevation) - (90) + ), # Zero123 polar is 90-elevation + 
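+            # hedged note: the four stacked values follow Zero123's camera
+            # conditioning (delta polar, sin azimuth, cos azimuth, constant);
+            # the last entry evaluates to deg2rad(90) since the offset passed
+            # to full_like below is zero.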
torch.sin(torch.deg2rad(azimuth)),
+            torch.cos(torch.deg2rad(azimuth)),
+            torch.deg2rad(
+                90 - torch.full_like(elevation, 0)
+            ),
+        ], dim=-1).unsqueeze(1)
+
+    return embeddings
+
+
+class StableZero123_Conditioning:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "clip_vision": ("CLIP_VISION",),
+                              "init_image": ("IMAGE",),
+                              "vae": ("VAE",),
+                              "width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+                              "height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                              "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
+                              "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
+                             }}
+    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+    RETURN_NAMES = ("positive", "negative", "latent")
+
+    FUNCTION = "encode"
+
+    CATEGORY = "conditioning/3d_models"
+
+    def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
+        output = clip_vision.encode_image(init_image)
+        pooled = output.image_embeds.unsqueeze(0)
+        pixels = utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
+        encode_pixels = pixels[:,:,:,:3]
+        t = vae.encode(encode_pixels)
+        cam_embeds = camera_embeddings(elevation, azimuth)
+        cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
+
+        positive = [[cond, {"concat_latent_image": t}]]
+        negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
+        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
+        return (positive, negative, {"samples":latent})
+
+NODE_CLASS_MAPPINGS = {
+    "StableZero123_Conditioning": StableZero123_Conditioning,
+}
diff --git a/setup.py b/setup.py
index e69d6146c..ca79d0db8 100644
--- a/setup.py
+++ b/setup.py
@@ -28,18 +28,18 @@ version = '0.0.1'
 """
 The package index to the torch built with AMD ROCm.
 """
-amd_torch_index = "https://download.pytorch.org/whl/rocm5.6"
+amd_torch_index = ("https://download.pytorch.org/whl/rocm5.6", "https://download.pytorch.org/whl/nightly/rocm5.7")
 """
 The package index to torch built with CUDA. Observe the CUDA version is in this URL.
 """
-nvidia_torch_index = "https://download.pytorch.org/whl/cu121"
+nvidia_torch_index = ("https://download.pytorch.org/whl/cu121", "https://download.pytorch.org/whl/nightly/cu121")
 """
 The package index to torch built against CPU features.
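 Each of these indices is now a (stable, nightly) pair; dependencies() below picks the nightly URL when running on Python 3.12, presumably because stable torch wheels were not yet published for it.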
""" -cpu_torch_index = "https://download.pytorch.org/whl/cpu" +cpu_torch_index = ("https://download.pytorch.org/whl/cpu", "https://download.pytorch.org/whl/nightly/cpu") # xformers not required for new torch @@ -102,11 +102,11 @@ def _is_linux_arm64(): def dependencies() -> List[str]: _dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines() - # todo: also add all plugin dependencies _alternative_indices = [amd_torch_index, nvidia_torch_index] session = PipSession() - index_urls = ['https://pypi.org/simple'] + # (stable, nightly) tuple + index_urls = [('https://pypi.org/simple', 'https://pypi.org/simple')] # prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device if _is_nvidia(): index_urls += [nvidia_torch_index] @@ -118,6 +118,13 @@ def dependencies() -> List[str]: if len(index_urls) == 1: return _dependencies + if sys.version_info >= (3, 12): + # use the nightlies + index_urls = [nightly for (_, nightly) in index_urls] + _alternative_indices = [nightly for (_, nightly) in _alternative_indices] + else: + index_urls = [stable for (stable, _) in index_urls] + _alternative_indices = [stable for (stable, _) in _alternative_indices] try: # pip 23 finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls, no_index=False)), @@ -149,7 +156,7 @@ setup( description="", author="", version=version, - python_requires=">=3.9,<3.12", + python_requires=">=3.9,<3.13", # todo: figure out how to include the web directory to eventually let main live inside the package # todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins packages=find_packages(exclude=[] if is_editable else ['custom_nodes']), diff --git a/tests-ui/afterSetup.js b/tests-ui/afterSetup.js new file mode 100644 index 000000000..983f3af64 --- /dev/null +++ b/tests-ui/afterSetup.js @@ -0,0 +1,9 @@ +const { start } = require("./utils"); +const lg = require("./utils/litegraph"); + +// Load things once per test file before to ensure its all warmed up for the tests +beforeAll(async () => { + lg.setup(global); + await start({ resetEnv: true }); + lg.teardown(global); +}); diff --git a/tests-ui/jest.config.js b/tests-ui/jest.config.js index b5a5d646d..86fff5057 100644 --- a/tests-ui/jest.config.js +++ b/tests-ui/jest.config.js @@ -2,8 +2,10 @@ const config = { testEnvironment: "jsdom", setupFiles: ["./globalSetup.js"], + setupFilesAfterEnv: ["./afterSetup.js"], clearMocks: true, resetModules: true, + testTimeout: 10000 }; module.exports = config; diff --git a/tests-ui/tests/extensions.test.js b/tests-ui/tests/extensions.test.js index b82e55c32..159e5113a 100644 --- a/tests-ui/tests/extensions.test.js +++ b/tests-ui/tests/extensions.test.js @@ -52,7 +52,7 @@ describe("extensions", () => { const nodeNames = Object.keys(defs); const nodeCount = nodeNames.length; expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); - for (let i = 0; i < nodeCount; i++) { + for (let i = 0; i < 10; i++) { // It should be send the JS class and the original JSON definition const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]; const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]; @@ -133,7 +133,7 @@ describe("extensions", () => { expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2); expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1); expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2); - }); + 
}, 15000); it("allows custom nodeDefs and widgets to be registered", async () => { const widgetMock = jest.fn((node, inputName, inputData, app) => { diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index ce54c1154..e6ebedd91 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -1,7 +1,7 @@ // @ts-check /// -const { start, createDefaultWorkflow } = require("../utils"); +const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils"); const lg = require("../utils/litegraph"); describe("group node", () => { @@ -273,7 +273,7 @@ describe("group node", () => { let reroutes = []; let prevNode = nodes.ckpt; - for(let i = 0; i < 5; i++) { + for (let i = 0; i < 5; i++) { const reroute = ez.Reroute(); prevNode.outputs[0].connectTo(reroute.inputs[0]); prevNode = reroute; @@ -283,7 +283,7 @@ describe("group node", () => { const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]); expect((await graph.toPrompt()).output).toEqual(getOutput()); - + group.menu["Convert to nodes"].call(); expect((await graph.toPrompt()).output).toEqual(getOutput()); }); @@ -383,6 +383,43 @@ describe("group node", () => { getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id]) ); }); + test("groups can connect to each other via internal reroutes", async () => { + const { ez, graph, app } = await start(); + + const latent = ez.EmptyLatentImage(); + const vae = ez.VAELoader(); + const latentReroute = ez.Reroute(); + const vaeReroute = ez.Reroute(); + + latent.outputs[0].connectTo(latentReroute.inputs[0]); + vae.outputs[0].connectTo(vaeReroute.inputs[0]); + + const group1 = await convertToGroup(app, graph, "test", [latentReroute, vaeReroute]); + group1.menu.Clone.call(); + expect(app.graph._nodes).toHaveLength(4); + const group2 = graph.find(app.graph._nodes[3]); + expect(group2.node.type).toEqual("workflow/test"); + expect(group2.id).not.toEqual(group1.id); + + group1.outputs.VAE.connectTo(group2.inputs.VAE); + group1.outputs.LATENT.connectTo(group2.inputs.LATENT); + + const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE); + const preview = ez.PreviewImage(decode.outputs[0]); + + const output = { + [latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" }, + [vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" }, + [decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" }, + [preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" }, + }; + expect((await graph.toPrompt()).output).toEqual(output); + + // Ensure missing connections dont cause errors + group2.inputs.VAE.disconnect(); + delete output[decode.id].inputs.vae; + expect((await graph.toPrompt()).output).toEqual(output); + }); test("displays generated image on group node", async () => { const { ez, graph, app } = await start(); const nodes = createDefaultWorkflow(ez, graph); @@ -642,6 +679,55 @@ describe("group node", () => { 2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" }, }); }); + test("correctly handles widget inputs", async () => { + const { ez, graph, app } = await start(); + const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0]; + + const image = ez.LoadImage(); + const scale1 = ez.ImageScaleBy(image.outputs[0]); + const scale2 = ez.ImageScaleBy(image.outputs[0]); + const preview1 = 
ez.PreviewImage(scale1.outputs[0]); + const preview2 = ez.PreviewImage(scale2.outputs[0]); + scale1.widgets.upscale_method.value = upscaleMethods[1]; + scale1.widgets.upscale_method.convertToInput(); + + const group = await convertToGroup(app, graph, "test", [scale1, scale2]); + expect(group.inputs.length).toBe(3); + expect(group.inputs[0].input.type).toBe("IMAGE"); + expect(group.inputs[1].input.type).toBe("IMAGE"); + expect(group.inputs[2].input.type).toBe("COMBO"); + + // Ensure links are maintained + expect(group.inputs[0].connection?.originNode?.id).toBe(image.id); + expect(group.inputs[1].connection?.originNode?.id).toBe(image.id); + expect(group.inputs[2].connection).toBeFalsy(); + + // Ensure primitive gets correct type + const primitive = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(group.inputs[2]); + expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods); + expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied + primitive.widgets.value.value = upscaleMethods[1]; + + await checkBeforeAndAfterReload(graph, async (r) => { + const scale1id = r ? `${group.id}:0` : scale1.id; + const scale2id = r ? `${group.id}:1` : scale2.id; + // Ensure widget value is applied to prompt + expect((await graph.toPrompt()).output).toStrictEqual({ + [image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" }, + [scale1id]: { + inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] }, + class_type: "ImageScaleBy", + }, + [scale2id]: { + inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] }, + class_type: "ImageScaleBy", + }, + [preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" }, + [preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" }, + }); + }); + }); test("adds widgets in node execution order", async () => { const { ez, graph, app } = await start(); const scale = ez.LatentUpscale(); @@ -815,4 +901,105 @@ describe("group node", () => { expect(p2.widgets.control_after_generate.value).toBe("randomize"); expect(p2.widgets.control_filter_list.value).toBe("/.+/"); }); + test("internal reroutes work with converted inputs and merge options", async () => { + const { ez, graph, app } = await start(); + const vae = ez.VAELoader(); + const latent = ez.EmptyLatentImage(); + const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE); + const scale = ez.ImageScale(decode.outputs.IMAGE); + ez.PreviewImage(scale.outputs.IMAGE); + + const r1 = ez.Reroute(); + const r2 = ez.Reroute(); + + latent.widgets.width.value = 64; + latent.widgets.height.value = 128; + + latent.widgets.width.convertToInput(); + latent.widgets.height.convertToInput(); + latent.widgets.batch_size.convertToInput(); + + scale.widgets.width.convertToInput(); + scale.widgets.height.convertToInput(); + + r1.inputs[0].input.label = "hbw"; + r1.outputs[0].connectTo(latent.inputs.height); + r1.outputs[0].connectTo(latent.inputs.batch_size); + r1.outputs[0].connectTo(scale.inputs.width); + + r2.inputs[0].input.label = "wh"; + r2.outputs[0].connectTo(latent.inputs.width); + r2.outputs[0].connectTo(scale.inputs.height); + + const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]); + + expect(group.inputs[0].input.type).toBe("VAE"); + expect(group.inputs[1].input.type).toBe("INT"); + expect(group.inputs[2].input.type).toBe("INT"); + + const p1 = ez.PrimitiveNode(); + const p2 = ez.PrimitiveNode(); + 
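+		// sketch of what the assertions below verify: a primitive driving several
+		// converted widgets through one reroute merges their options, keeping the
+		// most restrictive min/max, while (per the inline comments) step arrives
+		// as the widget step * 10.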
p1.outputs[0].connectTo(group.inputs[1]); + p2.outputs[0].connectTo(group.inputs[2]); + + expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max + expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max + expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + expect(p1.widgets.value.value).toBe(128); + expect(p2.widgets.value.value).toBe(64); + + p1.widgets.value.value = 16; + p2.widgets.value.value = 32; + + await checkBeforeAndAfterReload(graph, async (r) => { + const id = (v) => (r ? `${group.id}:` : "") + v; + expect((await graph.toPrompt()).output).toStrictEqual({ + 1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" }, + [id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" }, + [id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" }, + [id(4)]: { + inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] }, + class_type: "ImageScale", + }, + 5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" }, + }); + }); + }); + test("converted inputs with linked widgets map values correctly on creation", async () => { + const { ez, graph, app } = await start(); + const k1 = ez.KSampler(); + const k2 = ez.KSampler(); + k1.widgets.seed.convertToInput(); + k2.widgets.seed.convertToInput(); + + const rr = ez.Reroute(); + rr.outputs[0].connectTo(k1.inputs.seed); + rr.outputs[0].connectTo(k2.inputs.seed); + + const group = await convertToGroup(app, graph, "test", [k1, k2, rr]); + expect(group.widgets.steps.value).toBe(20); + expect(group.widgets.cfg.value).toBe(8); + expect(group.widgets.scheduler.value).toBe("normal"); + expect(group.widgets["KSampler steps"].value).toBe(20); + expect(group.widgets["KSampler cfg"].value).toBe(8); + expect(group.widgets["KSampler scheduler"].value).toBe("normal"); + }); + test("allow multiple of the same node type to be added", async () => { + const { ez, graph, app } = await start(); + const nodes = [...Array(10)].map(() => ez.ImageScaleBy()); + const group = await convertToGroup(app, graph, "test", nodes); + expect(group.inputs.length).toBe(10); + expect(group.outputs.length).toBe(10); + expect(group.widgets.length).toBe(20); + expect(group.widgets.map((w) => w.widget.name)).toStrictEqual( + [...Array(10)] + .map((_, i) => `${i > 0 ? "ImageScaleBy " : ""}${i > 1 ? 
i + " " : ""}`) + .flatMap((p) => [`${p}upscale_method`, `${p}scale_by`]) + ); + }); }); diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js index 8e191adf0..67e3fa341 100644 --- a/tests-ui/tests/widgetInputs.test.js +++ b/tests-ui/tests/widgetInputs.test.js @@ -1,7 +1,13 @@ // @ts-check /// -const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils"); +const { + start, + makeNodeDef, + checkBeforeAndAfterReload, + assertNotNullOrUndefined, + createDefaultWorkflow, +} = require("../utils"); const lg = require("../utils/litegraph"); /** @@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWi if (controlWidgetCount) { const controlWidget = primitive.widgets.control_after_generate; expect(controlWidget.widget.type).toBe("combo"); - if(widgetType === "combo") { + if (widgetType === "combo") { const filterWidget = primitive.widgets.control_filter_list; expect(filterWidget.widget.type).toBe("string"); } @@ -308,8 +314,8 @@ describe("widget inputs", () => { const { ez } = await start({ mockNodeDefs: { ...makeNodeDef("TestNode1", {}, [["A", "B"]]), - ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }), - ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }), + ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }), + ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }), }, }); @@ -330,7 +336,7 @@ describe("widget inputs", () => { const n1 = ez.TestNode1(); n1.widgets.example.convertToInput(); - const p = ez.PrimitiveNode() + const p = ez.PrimitiveNode(); p.outputs[0].connectTo(n1.inputs[0]); const value = p.widgets.value; @@ -380,7 +386,7 @@ describe("widget inputs", () => { // Check random control.value = "randomize"; filter.value = "/D/"; - for(let i = 0; i < 100; i++) { + for (let i = 0; i < 100; i++) { control["afterQueued"](); expect(value.value === "D" || value.value === "DD").toBeTruthy(); } @@ -392,4 +398,160 @@ describe("widget inputs", () => { control["afterQueued"](); expect(value.value).toBe("B"); }); + + describe("reroutes", () => { + async function checkOutput(graph, values) { + expect((await graph.toPrompt()).output).toStrictEqual({ + 1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" }, + 2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" }, + 3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" }, + 4: { + inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 }, + class_type: "EmptyLatentImage", + }, + 5: { + inputs: { + seed: 0, + steps: 20, + cfg: 8, + sampler_name: "euler", + scheduler: values?.scheduler ?? "normal", + denoise: 1, + model: ["1", 0], + positive: ["2", 0], + negative: ["3", 0], + latent_image: ["4", 0], + }, + class_type: "KSampler", + }, + 6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" }, + 7: { + inputs: { filename_prefix: values.filename_prefix ?? 
"ComfyUI", images: ["6", 0] }, + class_type: "SaveImage", + }, + }); + } + + async function waitForWidget(node) { + // widgets are created slightly after the graph is ready + // hard to find an exact hook to get these so just wait for them to be ready + for (let i = 0; i < 10; i++) { + await new Promise((r) => setTimeout(r, 10)); + if (node.widgets?.value) { + return; + } + } + } + + it("can connect primitive via a reroute path to a widget input", async () => { + const { ez, graph } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.empty.widgets.width.convertToInput(); + nodes.sampler.widgets.scheduler.convertToInput(); + nodes.save.widgets.filename_prefix.convertToInput(); + + let widthReroute = ez.Reroute(); + let schedulerReroute = ez.Reroute(); + let fileReroute = ez.Reroute(); + + let widthNext = widthReroute; + let schedulerNext = schedulerReroute; + let fileNext = fileReroute; + + for (let i = 0; i < 5; i++) { + let next = ez.Reroute(); + widthNext.outputs[0].connectTo(next.inputs[0]); + widthNext = next; + + next = ez.Reroute(); + schedulerNext.outputs[0].connectTo(next.inputs[0]); + schedulerNext = next; + + next = ez.Reroute(); + fileNext.outputs[0].connectTo(next.inputs[0]); + fileNext = next; + } + + widthNext.outputs[0].connectTo(nodes.empty.inputs.width); + schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler); + fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix); + + let widthPrimitive = ez.PrimitiveNode(); + let schedulerPrimitive = ez.PrimitiveNode(); + let filePrimitive = ez.PrimitiveNode(); + + widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]); + schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]); + filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]); + expect(widthPrimitive.widgets.value.value).toBe(512); + widthPrimitive.widgets.value.value = 1024; + expect(schedulerPrimitive.widgets.value.value).toBe("normal"); + schedulerPrimitive.widgets.value.value = "simple"; + expect(filePrimitive.widgets.value.value).toBe("ComfyUI"); + filePrimitive.widgets.value.value = "ComfyTest"; + + await checkBeforeAndAfterReload(graph, async () => { + widthPrimitive = graph.find(widthPrimitive); + schedulerPrimitive = graph.find(schedulerPrimitive); + filePrimitive = graph.find(filePrimitive); + await waitForWidget(filePrimitive); + expect(widthPrimitive.widgets.length).toBe(2); + expect(schedulerPrimitive.widgets.length).toBe(3); + expect(filePrimitive.widgets.length).toBe(1); + + await checkOutput(graph, { + width: 1024, + scheduler: "simple", + filename_prefix: "ComfyTest", + }); + }); + }); + it("can connect primitive via a reroute path to multiple widget inputs", async () => { + const { ez, graph } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.empty.widgets.width.convertToInput(); + nodes.empty.widgets.height.convertToInput(); + nodes.empty.widgets.batch_size.convertToInput(); + + let reroute = ez.Reroute(); + let prevReroute = reroute; + for (let i = 0; i < 5; i++) { + const next = ez.Reroute(); + prevReroute.outputs[0].connectTo(next.inputs[0]); + prevReroute = next; + } + + const r1 = ez.Reroute(prevReroute.outputs[0]); + const r2 = ez.Reroute(prevReroute.outputs[0]); + const r3 = ez.Reroute(r2.outputs[0]); + const r4 = ez.Reroute(r2.outputs[0]); + + r1.outputs[0].connectTo(nodes.empty.inputs.width); + r3.outputs[0].connectTo(nodes.empty.inputs.height); + r4.outputs[0].connectTo(nodes.empty.inputs.batch_size); + + let primitive = ez.PrimitiveNode(); + 
primitive.outputs[0].connectTo(reroute.inputs[0]); + expect(primitive.widgets.value.value).toBe(1); + primitive.widgets.value.value = 64; + + await checkBeforeAndAfterReload(graph, async (r) => { + primitive = graph.find(primitive); + await waitForWidget(primitive); + + // Ensure widget configs are merged + expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max + expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + await checkOutput(graph, { + width: 64, + height: 64, + batch_size: 64, + }); + }); + }); + }); }); diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js index 898b82db0..8a55246ee 100644 --- a/tests-ui/utils/ezgraph.js +++ b/tests-ui/utils/ezgraph.js @@ -78,6 +78,14 @@ export class EzInput extends EzSlot { this.input = input; } + get connection() { + const link = this.node.node.inputs?.[this.index]?.link; + if (link == null) { + return null; + } + return new EzConnection(this.node.app, this.node.app.graph.links[link]); + } + disconnect() { this.node.node.disconnectInput(this.index); } @@ -117,7 +125,7 @@ export class EzOutput extends EzSlot { const inp = input.input; const inName = inp.name || inp.label || inp.type; throw new Error( - `Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${ + `Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${ this.output.name ?? this.output.type }#${this.index}] failed.` ); @@ -179,6 +187,7 @@ export class EzWidget { set value(v) { this.widget.value = v; + this.widget.callback?.call?.(this.widget, v) } get isConvertedToInput() { @@ -319,7 +328,7 @@ export class EzGraph { } stringify() { - return JSON.stringify(this.app.graph.serialize(), undefined, "\t"); + return JSON.stringify(this.app.graph.serialize(), undefined); } /** diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js index 3a018f566..6a08e8594 100644 --- a/tests-ui/utils/index.js +++ b/tests-ui/utils/index.js @@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) { return { ckpt, pos, neg, empty, sampler, decode, save }; } + +export async function getNodeDefs() { + const { api } = require("../../web/scripts/api"); + return api.getNodeDefs(); +} + +export async function getNodeDef(nodeId) { + return (await getNodeDefs())[nodeId]; +} \ No newline at end of file diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 6766f356d..4cf1f7621 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -174,6 +174,11 @@ export class GroupNodeConfig { node.index = i; this.processNode(node, seenInputs, seenOutputs); } + + for (const p of this.#convertedToProcess) { + p(); + } + this.#convertedToProcess = null; await app.registerNodeDef("workflow/" + this.name, this.nodeDef); } @@ -192,7 +197,10 @@ export class GroupNodeConfig { if (!this.linksFrom[sourceNodeId]) { this.linksFrom[sourceNodeId] = {}; } - this.linksFrom[sourceNodeId][sourceNodeSlot] = l; + if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) { + this.linksFrom[sourceNodeId][sourceNodeSlot] = []; + } + this.linksFrom[sourceNodeId][sourceNodeSlot].push(l); if (!this.linksTo[targetNodeId]) { this.linksTo[targetNodeId] = {}; @@ -230,11 +238,11 @@ export class GroupNodeConfig { // Skip as its not linked if (!linksFrom) return; - let type = linksFrom["0"][5]; + let type = 
linksFrom["0"][0][5]; if (type === "COMBO") { // Use the array items const source = node.outputs[0].widget.name; - const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type; + const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type; const fromType = globalDefs[fromTypeName]; const input = fromType.input.required[source] ?? fromType.input.optional[source]; type = input[0]; @@ -258,10 +266,33 @@ export class GroupNodeConfig { return null; } + let config = {}; let rerouteType = "*"; if (linksFrom) { - const [, , id, slot] = linksFrom["0"]; - rerouteType = this.nodeData.nodes[id].inputs[slot].type; + for (const [, , id, slot] of linksFrom["0"]) { + const node = this.nodeData.nodes[id]; + const input = node.inputs[slot]; + if (rerouteType === "*") { + rerouteType = input.type; + } + if (input.widget) { + const targetDef = globalDefs[node.type]; + const targetWidget = + targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; + + const widget = [targetWidget[0], config]; + const res = mergeIfValid( + { + widget, + }, + targetWidget, + false, + null, + widget + ); + config = res?.customConfig ?? config; + } + } } else if (linksTo) { const [id, slot] = linksTo["0"]; rerouteType = this.nodeData.nodes[id].outputs[slot].type; @@ -282,10 +313,11 @@ export class GroupNodeConfig { } } + config.forceInput = true; return { input: { required: { - [rerouteType]: [rerouteType, {}], + [rerouteType]: [rerouteType, config], }, }, output: [rerouteType], @@ -299,16 +331,17 @@ export class GroupNodeConfig { getInputConfig(node, inputName, seenInputs, config, extra) { let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; + let key = name; let prefix = ""; // Special handling for primitive to include the title if it is set rather than just "value" if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) { prefix = `${node.title ?? node.type} `; - name = `${prefix}${inputName}`; + key = name = `${prefix}${inputName}`; if (name in seenInputs) { name = `${prefix}${seenInputs[name]} ${inputName}`; } } - seenInputs[name] = (seenInputs[name] ?? 1) + 1; + seenInputs[key] = (seenInputs[key] ?? 1) + 1; if (inputName === "seed" || inputName === "noise_seed") { if (!extra) extra = {}; @@ -420,10 +453,18 @@ export class GroupNodeConfig { defaultInput: true, }); this.nodeDef.input.required[name] = config; + this.newToOldWidgetMap[name] = { node, inputName }; + + if (!this.oldToNewWidgetMap[node.index]) { + this.oldToNewWidgetMap[node.index] = {}; + } + this.oldToNewWidgetMap[node.index][inputName] = name; + inputMap[slots.length + i] = this.inputCount++; } } + #convertedToProcess = []; processNodeInputs(node, seenInputs, inputs) { const inputMapping = []; @@ -434,7 +475,11 @@ export class GroupNodeConfig { const linksTo = this.linksTo[node.index] ?? 
{}; const inputMap = (this.oldToNewInputMap[node.index] = {}); this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); - this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs); + + // Converted inputs have to be processed after all other nodes as they'll be at the end of the list + this.#convertedToProcess.push(() => + this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs) + ); return inputMapping; } @@ -597,11 +642,19 @@ export class GroupNodeHandler { const output = this.groupData.newToOldOutputMap[link.origin_slot]; let innerNode = this.innerNodes[output.node.index]; let l; - while (innerNode.type === "Reroute") { + while (innerNode?.type === "Reroute") { l = innerNode.getInputLink(0); innerNode = innerNode.getInputNode(0); } + if (!innerNode) { + return null; + } + + if (l && GroupNodeHandler.isGroupNode(innerNode)) { + return innerNode.updateLink(l); + } + link.origin_id = innerNode.id; link.origin_slot = l?.origin_slot ?? output.slot; return link; @@ -665,6 +718,8 @@ export class GroupNodeHandler { top = newNode.pos[1]; } + if (!newNode.widgets) continue; + const map = this.groupData.oldToNewWidgetMap[innerNode.index]; if (map) { const widgets = Object.keys(map); @@ -721,7 +776,7 @@ export class GroupNodeHandler { } }; - const reconnectOutputs = () => { + const reconnectOutputs = (selectedIds) => { for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) { const output = node.outputs[groupOutputId]; if (!output.links) continue; @@ -861,7 +916,7 @@ export class GroupNodeHandler { if (innerNode.type === "PrimitiveNode") { innerNode.primitiveValue = newValue; const primitiveLinked = this.groupData.primitiveToWidget[old.node.index]; - for (const linked of primitiveLinked) { + for (const linked of primitiveLinked ?? []) { const node = this.innerNodes[linked.nodeId]; const widget = node.widgets.find((w) => w.name === linked.inputName); @@ -870,6 +925,18 @@ export class GroupNodeHandler { } } continue; + } else if (innerNode.type === "Reroute") { + const rerouteLinks = this.groupData.linksFrom[old.node.index]; + for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { + const node = this.innerNodes[targetNodeId]; + const input = node.inputs[targetSlot]; + if (input.widget) { + const widget = node.widgets?.find((w) => w.name === input.widget.name); + if (widget) { + widget.value = newValue; + } + } + } } const widget = innerNode.widgets?.find((w) => w.name === old.inputName); @@ -897,33 +964,58 @@ export class GroupNodeHandler { this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value; } } + return true; } + populateReroute(node, nodeId, map) { + if (node.type !== "Reroute") return; + + const link = this.groupData.linksFrom[nodeId]?.[0]?.[0]; + if (!link) return; + const [, , targetNodeId, targetNodeSlot] = link; + const targetNode = this.groupData.nodeData.nodes[targetNodeId]; + const inputs = targetNode.inputs; + const targetWidget = inputs?.[targetNodeSlot].widget; + if (!targetWidget) return; + + const offset = inputs.length - (targetNode.widgets_values?.length ?? 
0); + const v = targetNode.widgets_values?.[targetNodeSlot - offset]; + if (v == null) return; + + const widgetName = Object.values(map)[0]; + const widget = this.node.widgets.find(w => w.name === widgetName); + if(widget) { + widget.value = v; + } + } + + populateWidgets() { + if (!this.node.widgets) return; + for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) { const node = this.groupData.nodeData.nodes[nodeId]; - - if (!node.widgets_values?.length) continue; - - const map = this.groupData.oldToNewWidgetMap[nodeId]; + const map = this.groupData.oldToNewWidgetMap[nodeId] ?? {}; const widgets = Object.keys(map); + if (!node.widgets_values?.length) { + // special handling for populating values into reroutes + // this allows primitives connected to them to pick up the correct value + this.populateReroute(node, nodeId, map); + continue; + } + let linkedShift = 0; for (let i = 0; i < widgets.length; i++) { const oldName = widgets[i]; const newName = map[oldName]; const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); const mainWidget = this.node.widgets[widgetIndex]; - if (!newName) { - // New name will be null if its a converted widget - this.populatePrimitive(node, nodeId, oldName, i, linkedShift); - + if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift) || widgetIndex === -1) { // Find the inner widget and shift by the number of linked widgets as they will have been removed too const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName); - linkedShift += innerWidget.linkedWidgets?.length ?? 0; - continue; + linkedShift += innerWidget?.linkedWidgets?.length ?? 0; } - if (widgetIndex === -1) { continue; } diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 8ace79562..bb2f16d42 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -33,6 +33,19 @@ function loadedImageToBlob(image) { return blob; } +function loadImage(imagePath) { + return new Promise((resolve, reject) => { + const image = new Image(); + + image.onload = function() { + resolve(image); + }; + image.onerror = reject; // reject on load failure so awaiting callers don't hang forever + + image.src = imagePath; + }); +} + async function uploadMask(filepath, formData) { await api.fetchApi('/upload/mask', { method: 'POST', @@ -50,25 +62,25 @@ async function uploadMask(filepath, formData) { ClipspaceDialog.invalidatePreview(); } -function prepareRGB(image, backupCanvas, backupCtx) { +function prepare_mask(image, maskCanvas, maskCtx) { // paste mask data into alpha channel - backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height); - const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); + maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height); + const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height); - // refine mask image - for (let i = 0; i < backupData.data.length; i += 4) { - if(backupData.data[i+3] == 255) - backupData.data[i+3] = 0; + // invert mask + for (let i = 0; i < maskData.data.length; i += 4) { + if(maskData.data[i+3] == 255) + maskData.data[i+3] = 0; else - backupData.data[i+3] = 255; + maskData.data[i+3] = 255; - backupData.data[i] = 0; - backupData.data[i+1] = 0; - backupData.data[i+2] = 0; + maskData.data[i] = 0; + maskData.data[i+1] = 0; + maskData.data[i+2] = 0; } - backupCtx.globalCompositeOperation = 'source-over'; - backupCtx.putImageData(backupData, 0, 0); + maskCtx.globalCompositeOperation = 'source-over'; + maskCtx.putImageData(maskData, 0, 0); } class
MaskEditorDialog extends ComfyDialog { @@ -155,10 +167,6 @@ class MaskEditorDialog extends ComfyDialog { // If it is specified as relative, using it only as a hidden placeholder for padding is recommended // to prevent anomalies where it exceeds a certain size and goes outside of the window. - var placeholder = document.createElement("div"); - placeholder.style.position = "relative"; - placeholder.style.height = "50px"; - var bottom_panel = document.createElement("div"); bottom_panel.style.position = "absolute"; bottom_panel.style.bottom = "0px"; @@ -180,18 +188,16 @@ class MaskEditorDialog extends ComfyDialog { this.brush = brush; this.element.appendChild(imgCanvas); this.element.appendChild(maskCanvas); - this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button this.element.appendChild(bottom_panel); document.body.appendChild(brush); - var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { + this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { self.brush_size = event.target.value; self.updateBrushPreview(self, null, null); }); var clearButton = this.createLeftButton("Clear", () => { self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); - self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height); }); var cancelButton = this.createRightButton("Cancel", () => { document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); @@ -207,40 +213,42 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(imgCanvas); this.element.appendChild(maskCanvas); - this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button this.element.appendChild(bottom_panel); bottom_panel.appendChild(clearButton); bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(cancelButton); - bottom_panel.appendChild(brush_size_slider); + bottom_panel.appendChild(this.brush_size_slider); + + imgCanvas.style.position = "absolute"; + maskCanvas.style.position = "absolute"; - imgCanvas.style.position = "relative"; imgCanvas.style.top = "200"; imgCanvas.style.left = "0"; - maskCanvas.style.position = "absolute"; + maskCanvas.style.top = imgCanvas.style.top; + maskCanvas.style.left = imgCanvas.style.left; } - show() { + async show() { + this.zoom_ratio = 1.0; + this.pan_x = 0; + this.pan_y = 0; + if(!this.is_layout_created) { // layout const imgCanvas = document.createElement('canvas'); const maskCanvas = document.createElement('canvas'); - const backupCanvas = document.createElement('canvas'); imgCanvas.id = "imageCanvas"; maskCanvas.id = "maskCanvas"; - backupCanvas.id = "backupCanvas"; this.setlayout(imgCanvas, maskCanvas); // prepare content this.imgCanvas = imgCanvas; this.maskCanvas = maskCanvas; - this.backupCanvas = backupCanvas; - this.maskCtx = maskCanvas.getContext('2d'); - this.backupCtx = backupCanvas.getContext('2d'); + this.maskCtx = maskCanvas.getContext('2d', {willReadFrequently: true }); this.setEventHandler(maskCanvas); @@ -252,6 +260,8 @@ class MaskEditorDialog extends ComfyDialog { mutations.forEach(function(mutation) { if (mutation.type === 'attributes' && mutation.attributeName === 'style') { if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + self.brush.style.display = "none"; ComfyApp.onClipspaceEditorClosed(); } @@ -264,7 +274,8 @@ class MaskEditorDialog 
extends ComfyDialog { observer.observe(this.element, config); } - this.setImages(this.imgCanvas, this.backupCanvas); + // The keydown event needs to be reconfigured when closing the dialog as it gets removed. + document.addEventListener('keydown', MaskEditorDialog.handleKeyDown); if(ComfyApp.clipspace_return_node) { this.saveButton.innerText = "Save to node"; @@ -275,97 +286,157 @@ class MaskEditorDialog extends ComfyDialog { this.saveButton.disabled = false; this.element.style.display = "block"; + this.element.style.width = "85%"; + this.element.style.margin = "0 7.5%"; + this.element.style.height = "100vh"; + this.element.style.top = "50%"; + this.element.style.left = "42%"; this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority. + + await this.setImages(this.imgCanvas); + + this.is_visible = true; } isOpened() { return this.element.style.display == "block"; } - setImages(imgCanvas, backupCanvas) { - const imgCtx = imgCanvas.getContext('2d'); - const backupCtx = backupCanvas.getContext('2d'); + invalidateCanvas(orig_image, mask_image) { + this.imgCanvas.width = orig_image.width; + this.imgCanvas.height = orig_image.height; + + this.maskCanvas.width = orig_image.width; + this.maskCanvas.height = orig_image.height; + + let imgCtx = this.imgCanvas.getContext('2d', {willReadFrequently: true }); + let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true }); + + imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height); + prepare_mask(mask_image, this.maskCanvas, maskCtx); + } + + async setImages(imgCanvas) { + let self = this; + + const imgCtx = imgCanvas.getContext('2d', {willReadFrequently: true }); const maskCtx = this.maskCtx; const maskCanvas = this.maskCanvas; - backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height); maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height); // image load - const orig_image = new Image(); - window.addEventListener("resize", () => { - // repositioning - imgCanvas.width = window.innerWidth - 250; - imgCanvas.height = window.innerHeight - 200; - - // redraw image - let drawWidth = orig_image.width; - let drawHeight = orig_image.height; - if (orig_image.width > imgCanvas.width) { - drawWidth = imgCanvas.width; - drawHeight = (drawWidth / orig_image.width) * orig_image.height; - } - - if (drawHeight > imgCanvas.height) { - drawHeight = imgCanvas.height; - drawWidth = (drawHeight / orig_image.height) * orig_image.width; - } - - imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); - - // update mask - maskCanvas.width = drawWidth; - maskCanvas.height = drawHeight; - maskCanvas.style.top = imgCanvas.offsetTop + "px"; - maskCanvas.style.left = imgCanvas.offsetLeft + "px"; - backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); - maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); - }); - const filepath = ComfyApp.clipspace.images; - const touched_image = new Image(); - - touched_image.onload = function() { - backupCanvas.width = touched_image.width; - backupCanvas.height = touched_image.height; - - prepareRGB(touched_image, backupCanvas, backupCtx); - }; - const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src) alpha_url.searchParams.delete('channel'); alpha_url.searchParams.delete('preview'); alpha_url.searchParams.set('channel', 'a'); - 
touched_image.src = alpha_url; + let mask_image = await loadImage(alpha_url); // original image load - orig_image.onload = function() { - window.dispatchEvent(new Event('resize')); - }; - const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); rgb_url.searchParams.delete('channel'); rgb_url.searchParams.set('channel', 'rgb'); - orig_image.src = rgb_url; - this.image = orig_image; + this.image = new Image(); + this.image.onload = function() { + maskCanvas.width = self.image.width; + maskCanvas.height = self.image.height; + + self.invalidateCanvas(self.image, mask_image); + self.initializeCanvasPanZoom(); + }; + this.image.src = rgb_url; } - setEventHandler(maskCanvas) { - maskCanvas.addEventListener("contextmenu", (event) => { - event.preventDefault(); - }); + initializeCanvasPanZoom() { + // set the initial zoom and pan so the image fits inside the dialog + let drawWidth = this.image.width; + let drawHeight = this.image.height; + let width = this.element.clientWidth; + let height = this.element.clientHeight; + + if (this.image.width > width) { + drawWidth = width; + drawHeight = (drawWidth / this.image.width) * this.image.height; + } + + if (drawHeight > height) { + drawHeight = height; + drawWidth = (drawHeight / this.image.height) * this.image.width; + } + + this.zoom_ratio = drawWidth/this.image.width; + + const canvasX = (width - drawWidth) / 2; + const canvasY = (height - drawHeight) / 2; + this.pan_x = canvasX; + this.pan_y = canvasY; + + this.invalidatePanZoom(); + } + + + invalidatePanZoom() { + let raw_width = this.image.width * this.zoom_ratio; + let raw_height = this.image.height * this.zoom_ratio; + + if(this.pan_x + raw_width < 10) { + this.pan_x = 10 - raw_width; + } + + if(this.pan_y + raw_height < 10) { + this.pan_y = 10 - raw_height; + } + + let width = `${raw_width}px`; + let height = `${raw_height}px`; + + let left = `${this.pan_x}px`; + let top = `${this.pan_y}px`; + + this.maskCanvas.style.width = width; + this.maskCanvas.style.height = height; + this.maskCanvas.style.left = left; + this.maskCanvas.style.top = top; + + this.imgCanvas.style.width = width; + this.imgCanvas.style.height = height; + this.imgCanvas.style.left = left; + this.imgCanvas.style.top = top; + } + + + setEventHandler(maskCanvas) { const self = this; - maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event)); - maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event)); - document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp); - maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event)); - maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event)); - maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; }); - maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; }); - document.addEventListener('keydown', MaskEditorDialog.handleKeyDown); + + if(!this.handler_registered) { + maskCanvas.addEventListener("contextmenu", (event) => { + event.preventDefault(); + }); + + this.element.addEventListener('wheel', (event) => this.handleWheelEvent(self,event)); + this.element.addEventListener('pointermove', (event) => this.pointMoveEvent(self,event)); + this.element.addEventListener('touchmove', (event) => this.pointMoveEvent(self,event)); + + this.element.addEventListener('dragstart', (event) => { + if(event.ctrlKey) { + event.preventDefault(); + } + }); + + maskCanvas.addEventListener('pointerdown', (event) =>
this.handlePointerDown(self,event)); + maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; }); + maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; }); + + document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp); + + this.handler_registered = true; + } } brush_size = 10; @@ -378,8 +449,10 @@ class MaskEditorDialog extends ComfyDialog { const self = MaskEditorDialog.instance; if (event.key === ']') { self.brush_size = Math.min(self.brush_size+2, 100); + self.brush_slider_input.value = self.brush_size; } else if (event.key === '[') { self.brush_size = Math.max(self.brush_size-2, 1); + self.brush_slider_input.value = self.brush_size; } else if(event.key === 'Enter') { self.save(); } @@ -389,6 +462,10 @@ class MaskEditorDialog extends ComfyDialog { static handlePointerUp(event) { event.preventDefault(); + + this.mousedown_x = null; + this.mousedown_y = null; + MaskEditorDialog.instance.drawing_mode = false; } @@ -398,24 +475,70 @@ class MaskEditorDialog extends ComfyDialog { var centerX = self.cursorX; var centerY = self.cursorY; - brush.style.width = self.brush_size * 2 + "px"; - brush.style.height = self.brush_size * 2 + "px"; - brush.style.left = (centerX - self.brush_size) + "px"; - brush.style.top = (centerY - self.brush_size) + "px"; + brush.style.width = self.brush_size * 2 * this.zoom_ratio + "px"; + brush.style.height = self.brush_size * 2 * this.zoom_ratio + "px"; + brush.style.left = (centerX - self.brush_size * this.zoom_ratio) + "px"; + brush.style.top = (centerY - self.brush_size * this.zoom_ratio) + "px"; } handleWheelEvent(self, event) { - if(event.deltaY < 0) - self.brush_size = Math.min(self.brush_size+2, 100); - else - self.brush_size = Math.max(self.brush_size-2, 1); + event.preventDefault(); - self.brush_slider_input.value = self.brush_size; + if(event.ctrlKey) { + // zoom canvas + if(event.deltaY < 0) { + this.zoom_ratio = Math.min(10.0, this.zoom_ratio+0.2); + } + else { + this.zoom_ratio = Math.max(0.2, this.zoom_ratio-0.2); + } + + this.invalidatePanZoom(); + } + else { + // adjust brush size + if(event.deltaY < 0) + this.brush_size = Math.min(this.brush_size+2, 100); + else + this.brush_size = Math.max(this.brush_size-2, 1); + + this.brush_slider_input.value = this.brush_size; + + this.updateBrushPreview(this); + } + } + + pointMoveEvent(self, event) { + this.cursorX = event.pageX; + this.cursorY = event.pageY; self.updateBrushPreview(self); + + if(event.ctrlKey) { + event.preventDefault(); + self.pan_move(self, event); + } + } + + pan_move(self, event) { + if(event.buttons == 1) { + if(this.mousedown_x) { + let deltaX = this.mousedown_x - event.clientX; + let deltaY = this.mousedown_y - event.clientY; + + self.pan_x = this.mousedown_pan_x - deltaX; + self.pan_y = this.mousedown_pan_y - deltaY; + + self.invalidatePanZoom(); + } + } } draw_move(self, event) { + if(event.ctrlKey) { + return; + } + event.preventDefault(); this.cursorX = event.pageX; @@ -439,6 +562,9 @@ class MaskEditorDialog extends ComfyDialog { y = event.targetTouches[0].clientY - maskRect.top; } + x /= self.zoom_ratio; + y /= self.zoom_ratio; + var brush_size = this.brush_size; if(event instanceof PointerEvent && event.pointerType == 'pen') { brush_size *= event.pressure; @@ -489,8 +615,8 @@ class MaskEditorDialog extends ComfyDialog { } 
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) { const maskRect = self.maskCanvas.getBoundingClientRect(); - const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; - const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio; + const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio; var brush_size = this.brush_size; if(event instanceof PointerEvent && event.pointerType == 'pen') { @@ -540,6 +666,17 @@ class MaskEditorDialog extends ComfyDialog { } handlePointerDown(self, event) { + if(event.ctrlKey) { + if (event.buttons == 1) { + this.mousedown_x = event.clientX; + this.mousedown_y = event.clientY; + + this.mousedown_pan_x = this.pan_x; + this.mousedown_pan_y = this.pan_y; + } + return; + } + var brush_size = this.brush_size; if(event instanceof PointerEvent && event.pointerType == 'pen') { brush_size *= event.pressure; @@ -551,8 +688,8 @@ class MaskEditorDialog extends ComfyDialog { event.preventDefault(); const maskRect = self.maskCanvas.getBoundingClientRect(); - const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; - const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio; + const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio; self.maskCtx.beginPath(); if (event.button == 0) { @@ -570,15 +707,18 @@ class MaskEditorDialog extends ComfyDialog { } async save() { - const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); + const backupCanvas = document.createElement('canvas'); + const backupCtx = backupCanvas.getContext('2d', {willReadFrequently:true}); + backupCanvas.width = this.image.width; + backupCanvas.height = this.image.height; - backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + backupCtx.clearRect(0,0, backupCanvas.width, backupCanvas.height); backupCtx.drawImage(this.maskCanvas, 0, 0, this.maskCanvas.width, this.maskCanvas.height, - 0, 0, this.backupCanvas.width, this.backupCanvas.height); + 0, 0, backupCanvas.width, backupCanvas.height); // paste mask data into alpha channel - const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height); + const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); // refine mask image for (let i = 0; i < backupData.data.length; i += 4) { @@ -615,7 +755,7 @@ class MaskEditorDialog extends ComfyDialog { ComfyApp.clipspace.widgets[index].value = item; } - const dataURL = this.backupCanvas.toDataURL(); + const dataURL = backupCanvas.toDataURL(); const blob = dataURLToBlob(dataURL); let original_url = new URL(this.image.src); diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js index 499a171da..4feff91e5 100644 --- a/web/extensions/core/rerouteNode.js +++ b/web/extensions/core/rerouteNode.js @@ -1,10 +1,11 @@ import { app } from "../../scripts/app.js"; +import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs.js"; // Node that allows you to redirect connections for cleaner graphs app.registerExtension({ name: "Comfy.RerouteNode", - registerCustomNodes() { + registerCustomNodes(app) { class RerouteNode { constructor() { if (!this.properties) { @@ -16,6 +17,12 @@ app.registerExtension({ this.addInput("", "*"); 
this.addOutput(this.properties.showOutputText ? "*" : "", "*"); + this.onAfterGraphConfigured = function () { + requestAnimationFrame(() => { + this.onConnectionsChange(LiteGraph.INPUT, null, true, null); + }); + }; + this.onConnectionsChange = function (type, index, connected, link_info) { this.applyOrientation(); @@ -47,6 +54,7 @@ app.registerExtension({ const linkId = currentNode.inputs[0].link; if (linkId !== null) { const link = app.graph.links[linkId]; + if (!link) return; const node = app.graph.getNodeById(link.origin_id); const type = node.constructor.type; if (type === "Reroute") { @@ -54,8 +62,7 @@ app.registerExtension({ // We've found a circle currentNode.disconnectInput(link.target_slot); currentNode = null; - } - else { + } else { // Move the previous node currentNode = node; } @@ -94,8 +101,11 @@ app.registerExtension({ updateNodes.push(node); } else { // We've found an output - const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null; - if (inputType && nodeOutType !== inputType) { + const nodeOutType = + node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type + ? node.inputs[link.target_slot].type + : null; + if (inputType && inputType !== "*" && nodeOutType !== inputType) { // The output doesnt match our input so disconnect it node.disconnectInput(link.target_slot); } else { @@ -111,6 +121,9 @@ app.registerExtension({ const displayType = inputType || outputType || "*"; const color = LGraphCanvas.link_type_colors[displayType]; + let widgetConfig; + let targetWidget; + let widgetType; // Update the types of each node for (const node of updateNodes) { // If we dont have an input type we are always wildcard but we'll show the output type @@ -125,10 +138,38 @@ app.registerExtension({ const link = app.graph.links[l]; if (link) { link.color = color; + + if (app.configuringGraph) continue; + const targetNode = app.graph.getNodeById(link.target_id); + const targetInput = targetNode.inputs?.[link.target_slot]; + if (targetInput?.widget) { + const config = getWidgetConfig(targetInput); + if (!widgetConfig) { + widgetConfig = config[1] ?? {}; + widgetType = config[0]; + } + if (!targetWidget) { + targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name); + } + + const merged = mergeIfValid(targetInput, [config[0], widgetConfig]); + if (merged.customConfig) { + widgetConfig = merged.customConfig; + } + } } } } + for (const node of updateNodes) { + if (widgetConfig && outputType) { + node.inputs[0].widget = { name: "value" }; + setWidgetConfig(node.inputs[0], [widgetType ?? displayType, widgetConfig], targetWidget); + } else { + setWidgetConfig(node.inputs[0], null); + } + } + if (inputNode) { const link = app.graph.links[inputNode.inputs[0].link]; if (link) { @@ -173,8 +214,8 @@ app.registerExtension({ }, { // naming is inverted with respect to LiteGraphNode.horizontal - // LiteGraphNode.horizontal == true means that - // each slot in the inputs and outputs are layed out horizontally, + // LiteGraphNode.horizontal == true means that + // each slot in the inputs and outputs are layed out horizontally, // which is the opposite of the visual orientation of the inputs and outputs as a node content: "Set " + (this.properties.horizontal ? 
"Horizontal" : "Vertical"), callback: () => { @@ -187,7 +228,7 @@ app.registerExtension({ applyOrientation() { this.horizontal = this.properties.horizontal; if (this.horizontal) { - // we correct the input position, because LiteGraphNode.horizontal + // we correct the input position, because LiteGraphNode.horizontal // doesn't account for title presence // which reroute nodes don't have this.inputs[0].pos = [this.size[0] / 2, 0]; diff --git a/web/extensions/core/saveImageExtraOutput.js b/web/extensions/core/saveImageExtraOutput.js index 99e2213bf..a0506b43b 100644 --- a/web/extensions/core/saveImageExtraOutput.js +++ b/web/extensions/core/saveImageExtraOutput.js @@ -1,5 +1,5 @@ import { app } from "../../scripts/app.js"; - +import { applyTextReplacements } from "../../scripts/utils.js"; // Use widget values and dates in output filenames app.registerExtension({ @@ -7,84 +7,19 @@ app.registerExtension({ async beforeRegisterNodeDef(nodeType, nodeData, app) { if (nodeData.name === "SaveImage") { const onNodeCreated = nodeType.prototype.onNodeCreated; - - // Simple date formatter - const parts = { - d: (d) => d.getDate(), - M: (d) => d.getMonth() + 1, - h: (d) => d.getHours(), - m: (d) => d.getMinutes(), - s: (d) => d.getSeconds(), - }; - const format = - Object.keys(parts) - .map((k) => k + k + "?") - .join("|") + "|yyy?y?"; - - function formatDate(text, date) { - return text.replace(new RegExp(format, "g"), function (text) { - if (text === "yy") return (date.getFullYear() + "").substring(2); - if (text === "yyyy") return date.getFullYear(); - if (text[0] in parts) { - const p = parts[text[0]](date); - return (p + "").padStart(text.length, "0"); - } - return text; - }); - } - - // When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R + // When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R nodeType.prototype.onNodeCreated = function () { const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; const widget = this.widgets.find((w) => w.name === "filename_prefix"); widget.serializeValue = () => { - return widget.value.replace(/%([^%]+)%/g, function (match, text) { - const split = text.split("."); - if (split.length !== 2) { - // Special handling for dates - if (split[0].startsWith("date:")) { - return formatDate(split[0].substring(5), new Date()); - } - - if (text !== "width" && text !== "height") { - // Dont warn on standard replacements - console.warn("Invalid replacement pattern", text); - } - return match; - } - - // Find node with matching S&R property name - let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]); - // If we cant, see if there is a node with that title - if (!nodes.length) { - nodes = app.graph._nodes.filter((n) => n.title === split[0]); - } - if (!nodes.length) { - console.warn("Unable to find node", split[0]); - return match; - } - - if (nodes.length > 1) { - console.warn("Multiple nodes matched", split[0], "using first match"); - } - - const node = nodes[0]; - - const widget = node.widgets?.find((w) => w.name === split[1]); - if (!widget) { - console.warn("Unable to find widget", split[1], "on node", split[0], node); - return match; - } - - return ((widget.value ?? 
"") + "").replaceAll(/\/|\\/g, "_"); - }); + return applyTextReplacements(app, widget.value); }; return r; }; } else { - // When any other node is created add a property to alias the node + // When any other node is created add a property to alias the node const onNodeCreated = nodeType.prototype.onNodeCreated; nodeType.prototype.onNodeCreated = function () { const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js index c6613b0f0..3cb137520 100644 --- a/web/extensions/core/undoRedo.js +++ b/web/extensions/core/undoRedo.js @@ -71,24 +71,21 @@ function graphEqual(a, b, root = true) { } const undoRedo = async (e) => { + const updateState = async (source, target) => { + const prevState = source.pop(); + if (prevState) { + target.push(activeState); + isOurLoad = true; + await app.loadGraphData(prevState, false); + activeState = prevState; + } + } if (e.ctrlKey || e.metaKey) { if (e.key === "y") { - const prevState = redo.pop(); - if (prevState) { - undo.push(activeState); - isOurLoad = true; - await app.loadGraphData(prevState); - activeState = prevState; - } + updateState(redo, undo); return true; } else if (e.key === "z") { - const prevState = undo.pop(); - if (prevState) { - redo.push(activeState); - isOurLoad = true; - await app.loadGraphData(prevState); - activeState = prevState; - } + updateState(undo, redo); return true; } } diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index b6fa411f7..3f1c1f8c1 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -1,10 +1,16 @@ import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js"; import { app } from "../../scripts/app.js"; +import { applyTextReplacements } from "../../scripts/utils.js"; const CONVERTED_TYPE = "converted-widget"; const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; const CONFIG = Symbol(); const GET_CONFIG = Symbol(); +const TARGET = Symbol(); // Used for reroutes to specify the real target widget + +export function getWidgetConfig(slot) { + return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG](); +} function getConfig(widgetName) { const { nodeData } = this.constructor; @@ -100,7 +106,6 @@ function getWidgetType(config) { return { type }; } - function isValidCombo(combo, obj) { // New input isnt a combo if (!(obj instanceof Array)) { @@ -121,6 +126,31 @@ function isValidCombo(combo, obj) { return true; } +export function setWidgetConfig(slot, config, target) { + if (!slot.widget) return; + if (config) { + slot.widget[GET_CONFIG] = () => config; + slot.widget[TARGET] = target; + } else { + delete slot.widget; + } + + if (slot.link) { + const link = app.graph.links[slot.link]; + if (link) { + const originNode = app.graph.getNodeById(link.origin_id); + if (originNode.type === "PrimitiveNode") { + if (config) { + originNode.recreateWidget(); + } else if(!app.configuringGraph) { + originNode.disconnectOutput(0); + originNode.onLastDisconnect(); + } + } + } + } +} + export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) { if (!config1) { config1 = output.widget[CONFIG] ?? 
output.widget[GET_CONFIG](); @@ -150,7 +180,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi const isNumber = config1[0] === "INT" || config1[0] === "FLOAT"; for (const k of keys.values()) { - if (k !== "default" && k !== "forceInput" && k !== "defaultInput") { + if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline") { let v1 = config1[1][k]; let v2 = config2[1]?.[k]; @@ -405,11 +435,16 @@ app.registerExtension({ }; }, registerCustomNodes() { + const replacePropertyName = "Run widget replace on values"; class PrimitiveNode { constructor() { this.addOutput("connect to widget input", "*"); this.serialize_widgets = true; this.isVirtualNode = true; + + if (!this.properties || !(replacePropertyName in this.properties)) { + this.addProperty(replacePropertyName, false, "boolean"); + } } applyToGraph(extraLinks = []) { @@ -430,18 +465,29 @@ app.registerExtension({ } let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks]; + let v = this.widgets?.[0].value; + if(v && this.properties[replacePropertyName]) { + v = applyTextReplacements(app, v); + } + // For each output link copy our value over the original widget value for (const linkInfo of links) { const node = this.graph.getNodeById(linkInfo.target_id); const input = node.inputs[linkInfo.target_slot]; - const widgetName = input.widget.name; - if (widgetName) { - const widget = node.widgets.find((w) => w.name === widgetName); - if (widget) { - widget.value = this.widgets[0].value; - if (widget.callback) { - widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); - } + let widget; + if (input.widget[TARGET]) { + widget = input.widget[TARGET]; + } else { + const widgetName = input.widget.name; + if (widgetName) { + widget = node.widgets.find((w) => w.name === widgetName); + } + } + + if (widget) { + widget.value = v; + if (widget.callback) { + widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); } } } @@ -494,14 +540,13 @@ app.registerExtension({ this.#mergeWidgetConfig(); if (!links?.length) { - this.#onLastDisconnect(); + this.onLastDisconnect(); } } } onConnectOutput(slot, type, input, target_node, target_slot) { // Fires before the link is made allowing us to reject it if it isn't valid - // No widget, we cant connect if (!input.widget) { if (!(input.type in ComfyWidgets)) return false; @@ -519,6 +564,10 @@ app.registerExtension({ #onFirstConnection(recreating) { // First connection can fire before the graph is ready on initial load so random things can be missing + if (!this.outputs[0].links) { + this.onLastDisconnect(); + return; + } const linkId = this.outputs[0].links[0]; const link = this.graph.links[linkId]; if (!link) return; @@ -546,10 +595,10 @@ app.registerExtension({ this.outputs[0].name = type; this.outputs[0].widget = widget; - this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating); + this.#createWidget(widget[CONFIG] ?? 
config, theirNode, widget.name, recreating, widget[TARGET]); } - #createWidget(inputData, node, widgetName, recreating) { + #createWidget(inputData, node, widgetName, recreating, targetWidget) { let type = inputData[0]; if (type instanceof Array) { @@ -563,7 +612,9 @@ app.registerExtension({ widget = this.addWidget(type, "value", null, () => {}, {}); } - if (node?.widgets && widget) { + if (targetWidget) { + widget.value = targetWidget.value; + } else if (node?.widgets && widget) { const theirWidget = node.widgets.find((w) => w.name === widgetName); if (theirWidget) { widget.value = theirWidget.value; @@ -577,11 +628,19 @@ app.registerExtension({ } addValueControlWidgets(this, widget, control_value, undefined, inputData); let filter = this.widgets_values?.[2]; - if(filter && this.widgets.length === 3) { + if (filter && this.widgets.length === 3) { this.widgets[2].value = filter; } } + // Restore any saved control values + const controlValues = this.controlValues; + if(this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) { + for(let i = 0; i < controlValues.length; i++) { + this.widgets[i + 1].value = controlValues[i]; + } + } + // When our value changes, update other widgets to reflect our changes // e.g. so LoadImage shows correct image const callback = widget.callback; @@ -610,12 +669,14 @@ app.registerExtension({ } } - #recreateWidget() { - const values = this.widgets.map((w) => w.value); + recreateWidget() { + const values = this.widgets?.map((w) => w.value); this.#removeWidgets(); this.#onFirstConnection(true); - for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; - return this.widgets[0]; + if (values?.length) { + for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; + } + return this.widgets?.[0]; } #mergeWidgetConfig() { @@ -631,7 +692,7 @@ app.registerExtension({ if (links?.length < 2 && hasConfig) { // Copy the widget options from the source if (links.length) { - this.#recreateWidget(); + this.recreateWidget(); } return; @@ -657,7 +718,7 @@ app.registerExtension({ // Only allow connections where the configs match const output = this.outputs[0]; const config2 = input.widget[GET_CONFIG](); - return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget); + return !!mergeIfValid.call(this, output, config2, forceUpdate, this.recreateWidget); } #removeWidgets() { @@ -668,11 +729,20 @@ app.registerExtension({ w.onRemove(); } } + + // Temporarily store the current values in case the node is being recreated + // e.g. 
by group node conversion + this.controlValues = []; + this.lastType = this.widgets[0]?.type; + for(let i = 1; i < this.widgets.length; i++) { + this.controlValues.push(this.widgets[i].value); + } + setTimeout(() => { delete this.lastType; delete this.controlValues }, 15); this.widgets.length = 0; } } - #onLastDisconnect() { + onLastDisconnect() { // We cant remove + re-add the output here as if you drag a link over the same link // it removes, then re-adds, causing it to break this.outputs[0].type = "*"; diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index f571edb30..434c4a83b 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -48,7 +48,7 @@ EVENT_LINK_COLOR: "#A86", CONNECTING_LINK_COLOR: "#AFA", - MAX_NUMBER_OF_NODES: 1000, //avoid infinite loops + MAX_NUMBER_OF_NODES: 10000, //avoid infinite loops DEFAULT_POSITION: [100, 100], //default node position VALID_SHAPES: ["default", "box", "round", "card"], //,"circle" @@ -3788,16 +3788,42 @@ /** * returns the bounding of the object, used for rendering purposes - * bounding is: [topleft_cornerx, topleft_cornery, width, height] * @method getBounding - * @return {Float32Array[4]} the total size + * @param out {Float32Array[4]?} [optional] a place to store the output, to free garbage + * @param compute_outer {boolean?} [optional] set to true to include the shadow and connection points in the bounding calculation + * @return {Float32Array[4]} the bounding box in format of [topleft_cornerx, topleft_cornery, width, height] */ - LGraphNode.prototype.getBounding = function(out) { + LGraphNode.prototype.getBounding = function(out, compute_outer) { out = out || new Float32Array(4); - out[0] = this.pos[0] - 4; - out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT; - out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4; - out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT; + const nodePos = this.pos; + const isCollapsed = this.flags.collapsed; + const nodeSize = this.size; + + let left_offset = 0; + // 1 offset due to how nodes are rendered + let right_offset = 1 ; + let top_offset = 0; + let bottom_offset = 0; + + if (compute_outer) { + // 4 offset for collapsed node connection points + left_offset = 4; + // 6 offset for right shadow and collapsed node connection points + right_offset = 6 + left_offset; + // 4 offset for collapsed nodes top connection points + top_offset = 4; + // 5 offset for bottom shadow and collapsed node connection points + bottom_offset = 5 + top_offset; + } + + out[0] = nodePos[0] - left_offset; + out[1] = nodePos[1] - LiteGraph.NODE_TITLE_HEIGHT - top_offset; + out[2] = isCollapsed ? + (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) + right_offset : + nodeSize[0] + right_offset; + out[3] = isCollapsed ? 
diff --git a/web/scripts/app.js b/web/scripts/app.js
index 5faf41fb3..7353f5a3b 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -1559,9 +1559,12 @@ export class ComfyApp {
 	/**
 	 * Populates the graph with the specified workflow data
 	 * @param {*} graphData A serialized graph object
+	 * @param { boolean } clean If the graph state, e.g. images, should be cleared
 	 */
-	async loadGraphData(graphData) {
-		this.clean();
+	async loadGraphData(graphData, clean = true) {
+		if (clean !== false) {
+			this.clean();
+		}
 
 		let reset_invalid_values = false;
 		if (!graphData) {
@@ -1771,15 +1774,26 @@ export class ComfyApp {
 					if (parent?.updateLink) {
 						link = parent.updateLink(link);
 					}
-					inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
+					if (link) {
+						inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
+					}
 				}
 			}
 		}
 
-			output[String(node.id)] = {
+			let node_data = {
 				inputs,
 				class_type: node.comfyClass,
 			};
+
+			if (this.ui.settings.getSettingValue("Comfy.DevMode")) {
+				// Ignored by the backend.
+				node_data["_meta"] = {
+					title: node.title,
+				}
+			}
+
+			output[String(node.id)] = node_data;
 		}
 	}
@@ -2006,12 +2020,8 @@ export class ComfyApp {
 	async refreshComboInNodes() {
 		const defs = await api.getNodeDefs();
 
-		for(const nodeId in LiteGraph.registered_node_types) {
-			const node = LiteGraph.registered_node_types[nodeId];
-			const nodeDef = defs[nodeId];
-			if(!nodeDef) continue;
-
-			node.nodeData = nodeDef;
+		for (const nodeId in defs) {
+			this.registerNodeDef(nodeId, defs[nodeId]);
 		}
 
 		for(let nodeNum in this.graph._nodes) {
diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js
index e919428a0..eb0742d38 100644
--- a/web/scripts/domWidget.js
+++ b/web/scripts/domWidget.js
@@ -177,6 +177,7 @@ LGraphCanvas.prototype.computeVisibleNodes = function () {
 		for (const w of node.widgets) {
 			if (w.element) {
 				w.element.hidden = hidden;
+				w.element.style.display = hidden ? "none" : undefined;
 				if (hidden) {
 					w.options.onHide?.(w);
 				}
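
With the Comfy.DevMode setting enabled, each serialized node entry built in the app.js hunk above now carries an informational _meta block (ignored by the backend, per the comment in the diff). A made-up example of the resulting shape; the node id, class_type, input names and title are all illustrative:

    const output = {
        "3": {
            inputs: {
                seed: 42,         // plain widget value
                model: ["1", 0],  // link: [origin node id as string, origin output slot]
            },
            class_type: "KSampler",
            _meta: {
                title: "KSampler (main pass)", // the node's display title
            },
        },
    };
    console.log(JSON.stringify(output, null, 2));
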
"none" : undefined; if (hidden) { w.options.onHide?.(w); } diff --git a/web/scripts/utils.js b/web/scripts/utils.js new file mode 100644 index 000000000..401aca9e4 --- /dev/null +++ b/web/scripts/utils.js @@ -0,0 +1,67 @@ +// Simple date formatter +const parts = { + d: (d) => d.getDate(), + M: (d) => d.getMonth() + 1, + h: (d) => d.getHours(), + m: (d) => d.getMinutes(), + s: (d) => d.getSeconds(), +}; +const format = + Object.keys(parts) + .map((k) => k + k + "?") + .join("|") + "|yyy?y?"; + +function formatDate(text, date) { + return text.replace(new RegExp(format, "g"), function (text) { + if (text === "yy") return (date.getFullYear() + "").substring(2); + if (text === "yyyy") return date.getFullYear(); + if (text[0] in parts) { + const p = parts[text[0]](date); + return (p + "").padStart(text.length, "0"); + } + return text; + }); +} + +export function applyTextReplacements(app, value) { + return value.replace(/%([^%]+)%/g, function (match, text) { + const split = text.split("."); + if (split.length !== 2) { + // Special handling for dates + if (split[0].startsWith("date:")) { + return formatDate(split[0].substring(5), new Date()); + } + + if (text !== "width" && text !== "height") { + // Dont warn on standard replacements + console.warn("Invalid replacement pattern", text); + } + return match; + } + + // Find node with matching S&R property name + let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]); + // If we cant, see if there is a node with that title + if (!nodes.length) { + nodes = app.graph._nodes.filter((n) => n.title === split[0]); + } + if (!nodes.length) { + console.warn("Unable to find node", split[0]); + return match; + } + + if (nodes.length > 1) { + console.warn("Multiple nodes matched", split[0], "using first match"); + } + + const node = nodes[0]; + + const widget = node.widgets?.find((w) => w.name === split[1]); + if (!widget) { + console.warn("Unable to find widget", split[1], "on node", split[0], node); + return match; + } + + return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_"); + }); +}