diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat
deleted file mode 100755
index 94f5d1023..000000000
--- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat
+++ /dev/null
@@ -1,3 +0,0 @@
-..\python_embeded\python.exe .\update.py ..\ComfyUI\
-..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
-pause
diff --git a/.ci/nightly/windows_base_files/run_nvidia_gpu.bat b/.ci/nightly/windows_base_files/run_nvidia_gpu.bat
deleted file mode 100755
index 8ee2f3402..000000000
--- a/.ci/nightly/windows_base_files/run_nvidia_gpu.bat
+++ /dev/null
@@ -1,2 +0,0 @@
-.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention
-pause
diff --git a/.github/workflows/test-ui.yaml b/.github/workflows/test-ui.yaml
index 950691755..4b8b97934 100644
--- a/.github/workflows/test-ui.yaml
+++ b/.github/workflows/test-ui.yaml
@@ -22,5 +22,5 @@ jobs:
run: |
npm ci
npm run test:generate
- npm test
+ npm test -- --verbose
working-directory: ./tests-ui
diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml
index b793f7fe2..90e09d27a 100644
--- a/.github/workflows/windows_release_nightly_pytorch.yml
+++ b/.github/workflows/windows_release_nightly_pytorch.yml
@@ -2,6 +2,24 @@ name: "Windows Release Nightly pytorch"
on:
workflow_dispatch:
+ inputs:
+ cu:
+ description: 'cuda version'
+ required: true
+ type: string
+ default: "121"
+
+ python_minor:
+ description: 'python minor version'
+ required: true
+ type: string
+ default: "12"
+
+ python_patch:
+ description: 'python patch version'
+ required: true
+ type: string
+ default: "1"
# push:
# branches:
# - master
@@ -20,21 +38,21 @@ jobs:
persist-credentials: false
- uses: actions/setup-python@v4
with:
- python-version: '3.11.6'
+ python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
cd ..
cp -r ComfyUI ComfyUI_copy
- curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip
+ curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded
cd python_embeded
- echo 'import site' >> ./python311._pth
+ echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
- python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+ python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
- sed -i '1i../ComfyUI' ./python311._pth
+ sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
cd ..
git clone https://github.com/comfyanonymous/taesd
@@ -49,9 +67,10 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
- cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
- cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
+ echo "..\python_embeded\python.exe .\update.py ..\ComfyUI\\
+ ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
+ pause" > ./update/update_comfyui_and_python_dependencies.bat
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
diff --git a/.vscode/settings.json b/.vscode/settings.json
deleted file mode 100644
index 202121e10..000000000
--- a/.vscode/settings.json
+++ /dev/null
@@ -1,9 +0,0 @@
-{
- "path-intellisense.mappings": {
- "../": "${workspaceFolder}/web/extensions/core"
- },
- "[python]": {
- "editor.defaultFormatter": "ms-python.autopep8"
- },
- "python.formatting.provider": "none"
-}
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py
index 625fae33f..ff1ca98c9 100644
--- a/comfy/cldm/cldm.py
+++ b/comfy/cldm/cldm.py
@@ -53,7 +53,7 @@ class ControlNet(nn.Module):
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
- operations=ops,
+ operations=ops.disable_weight_init,
**kwargs,
):
super().__init__()
@@ -141,24 +141,24 @@ class ControlNet(nn.Module):
)
]
)
- self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
self.input_hint_block = TimestepEmbedSequential(
- operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
- operations.conv_nd(dims, 16, 16, 3, padding=1),
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
- operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
- operations.conv_nd(dims, 32, 32, 3, padding=1),
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
- operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
- operations.conv_nd(dims, 96, 96, 3, padding=1),
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
- operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
- zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
)
self._feature_size = model_channels
@@ -206,7 +206,7 @@ class ControlNet(nn.Module):
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
- self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
@@ -234,7 +234,7 @@ class ControlNet(nn.Module):
)
ch = out_ch
input_block_chans.append(ch)
- self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
ds *= 2
self._feature_size += ch
@@ -276,14 +276,14 @@ class ControlNet(nn.Module):
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
- self.middle_block_out = self.make_zero_conv(ch, operations=operations)
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch
- def make_zero_conv(self, channels, operations=None):
- return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context)
@@ -295,7 +295,7 @@ class ControlNet(nn.Module):
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
- h = x.type(self.dtype)
+ h = x
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context)
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 58e74f62f..a9d23b446 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -55,13 +55,19 @@ fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
-parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+fpunet_group = parser.add_mutually_exclusive_group()
+fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
+fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
+fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
+parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
+
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
@@ -98,7 +104,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
-
+parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
diff --git a/comfy/clip_model.py b/comfy/clip_model.py
new file mode 100644
index 000000000..7397b7a26
--- /dev/null
+++ b/comfy/clip_model.py
@@ -0,0 +1,188 @@
+import torch
+from comfy.ldm.modules.attention import optimized_attention_for_device
+
+class CLIPAttention(torch.nn.Module):
+ def __init__(self, embed_dim, heads, dtype, device, operations):
+ super().__init__()
+
+ self.heads = heads
+ self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+ self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+ self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+ self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+ def forward(self, x, mask=None, optimized_attention=None):
+ q = self.q_proj(x)
+ k = self.k_proj(x)
+ v = self.v_proj(x)
+
+ out = optimized_attention(q, k, v, self.heads, mask)
+ return self.out_proj(out)
+
+ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
+ "gelu": torch.nn.functional.gelu,
+}
+
+class CLIPMLP(torch.nn.Module):
+ def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
+ super().__init__()
+ self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
+ self.activation = ACTIVATIONS[activation]
+ self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.activation(x)
+ x = self.fc2(x)
+ return x
+
+class CLIPLayer(torch.nn.Module):
+ def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+ super().__init__()
+ self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+ self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
+ self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+ self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
+
+ def forward(self, x, mask=None, optimized_attention=None):
+ x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
+ x += self.mlp(self.layer_norm2(x))
+ return x
+
+
+class CLIPEncoder(torch.nn.Module):
+ def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+ super().__init__()
+ self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
+
+ def forward(self, x, mask=None, intermediate_output=None):
+ optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
+
+ if intermediate_output is not None:
+ if intermediate_output < 0:
+ intermediate_output = len(self.layers) + intermediate_output
+
+ intermediate = None
+ for i, l in enumerate(self.layers):
+ x = l(x, mask, optimized_attention)
+ if i == intermediate_output:
+ intermediate = x.clone()
+ return x, intermediate
+
+class CLIPEmbeddings(torch.nn.Module):
+ def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
+ super().__init__()
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+ self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+ def forward(self, input_tokens):
+ return self.token_embedding(input_tokens) + self.position_embedding.weight
+
+
+class CLIPTextModel_(torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ num_layers = config_dict["num_hidden_layers"]
+ embed_dim = config_dict["hidden_size"]
+ heads = config_dict["num_attention_heads"]
+ intermediate_size = config_dict["intermediate_size"]
+ intermediate_activation = config_dict["hidden_act"]
+
+ super().__init__()
+ self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+ self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+ self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+
+ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
+ x = self.embeddings(input_tokens)
+ mask = None
+ if attention_mask is not None:
+ mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+ mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+
+ causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+ if mask is not None:
+ mask += causal_mask
+ else:
+ mask = causal_mask
+
+ x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
+ x = self.final_layer_norm(x)
+ if i is not None and final_layer_norm_intermediate:
+ i = self.final_layer_norm(i)
+
+ pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
+ return x, i, pooled_output
+
+class CLIPTextModel(torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ self.num_layers = config_dict["num_hidden_layers"]
+ self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
+ self.dtype = dtype
+
+ def get_input_embeddings(self):
+ return self.text_model.embeddings.token_embedding
+
+ def set_input_embeddings(self, embeddings):
+ self.text_model.embeddings.token_embedding = embeddings
+
+ def forward(self, *args, **kwargs):
+ return self.text_model(*args, **kwargs)
+
+class CLIPVisionEmbeddings(torch.nn.Module):
+ def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+
+ self.patch_embedding = operations.Conv2d(
+ in_channels=num_channels,
+ out_channels=embed_dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=False,
+ dtype=dtype,
+ device=device
+ )
+
+ num_patches = (image_size // patch_size) ** 2
+ num_positions = num_patches + 1
+ self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+ def forward(self, pixel_values):
+ embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
+ return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
+
+
+class CLIPVision(torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ num_layers = config_dict["num_hidden_layers"]
+ embed_dim = config_dict["hidden_size"]
+ heads = config_dict["num_attention_heads"]
+ intermediate_size = config_dict["intermediate_size"]
+ intermediate_activation = config_dict["hidden_act"]
+
+ self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
+ self.pre_layrnorm = operations.LayerNorm(embed_dim)
+ self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+ self.post_layernorm = operations.LayerNorm(embed_dim)
+
+ def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
+ x = self.embeddings(pixel_values)
+ x = self.pre_layrnorm(x)
+ #TODO: attention_mask?
+ x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
+ pooled_output = self.post_layernorm(x[:, 0, :])
+ return x, i, pooled_output
+
+class CLIPVisionModelProjection(torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ self.vision_model = CLIPVision(config_dict, dtype, device, operations)
+ self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+
+ def forward(self, *args, **kwargs):
+ x = self.vision_model(*args, **kwargs)
+ out = self.visual_projection(x[2])
+ return (x[0], x[1], out)
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 1b109ff45..d3fdd5223 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -1,36 +1,43 @@
-from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
from .utils import load_torch_file, transformers_convert
import os
import torch
-import contextlib
+import json
+
from . import ops
from . import model_patcher
from . import model_management
+from . import clip_model
+
+
+class Output:
+ def __getitem__(self, key):
+ return getattr(self, key)
+ def __setitem__(self, key, item):
+ setattr(self, key, item)
def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
- scale = (size / min(image.shape[1], image.shape[2]))
- image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True)
- h = (image.shape[2] - size)//2
- w = (image.shape[3] - size)//2
- image = image[:,:,h:h+size,w:w+size]
+ image = image.movedim(-1, 1)
+ if not (image.shape[2] == size and image.shape[3] == size):
+ scale = (size / min(image.shape[2], image.shape[3]))
+ image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
+ h = (image.shape[2] - size)//2
+ w = (image.shape[3] - size)//2
+ image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
class ClipVisionModel():
def __init__(self, json_config):
- config = CLIPVisionConfig.from_json_file(json_config)
+ with open(json_config) as f:
+ config = json.load(f)
+
self.load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
- self.dtype = torch.float32
- if model_management.should_use_fp16(self.load_device, prioritize_performance=False):
- self.dtype = torch.float16
-
- with ops.use_comfy_ops(offload_device, self.dtype):
- with modeling_utils.no_init_weights():
- self.model = CLIPVisionModelWithProjection(config)
- self.model.to(self.dtype)
+ self.dtype = model_management.text_encoder_dtype(self.load_device)
+ self.model = clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ops.manual_cast)
+ self.model.eval()
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
@@ -38,25 +45,13 @@ class ClipVisionModel():
def encode_image(self, image):
model_management.load_model_gpu(self.patcher)
- pixel_values = clip_preprocess(image.to(self.load_device))
-
- if self.dtype != torch.float32:
- precision_scope = torch.autocast
- else:
- precision_scope = lambda a, b: contextlib.nullcontext(a)
-
- with precision_scope(model_management.get_autocast_device(self.load_device), torch.float32):
- outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
-
- for k in outputs:
- t = outputs[k]
- if t is not None:
- if k == 'hidden_states':
- outputs["penultimate_hidden_states"] = t[-2].cpu()
- outputs["hidden_states"] = None
- else:
- outputs[k] = t.cpu()
+ pixel_values = clip_preprocess(image.to(self.load_device)).float()
+ out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+ outputs = Output()
+ outputs["last_hidden_state"] = out[0].to(model_management.intermediate_device())
+ outputs["image_embeds"] = out[2].to(model_management.intermediate_device())
+ outputs["penultimate_hidden_states"] = out[1].to(model_management.intermediate_device())
return outputs
def convert_to_transformers(sd, prefix):
@@ -86,6 +81,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
if convert_keys:
sd = convert_to_transformers(sd, prefix)
if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
+ # todo: fix the importlib issue here
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 825953552..59ba52cf9 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -13,6 +13,8 @@ import typing
from dataclasses import dataclass
from typing import Tuple
import sys
+import gc
+import inspect
import torch
@@ -442,6 +444,8 @@ class PromptExecutor:
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None
+ if model_management.DISABLE_SMART_MEMORY:
+ model_management.unload_all_models()
@@ -462,6 +466,14 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
errors = []
valid = True
+ # todo: investigate if these are at the right indent level
+ info = None
+ val = None
+
+ validate_function_inputs = []
+ if hasattr(obj_class, "VALIDATE_INPUTS"):
+ validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
+
for x in required_inputs:
if x not in inputs:
error = {
@@ -591,29 +603,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
errors.append(error)
continue
- if hasattr(obj_class, "VALIDATE_INPUTS"):
- input_data_all = get_input_data(inputs, obj_class, unique_id)
- # ret = obj_class.VALIDATE_INPUTS(**input_data_all)
- ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
- for i, r3 in enumerate(ret):
- if r3 is not True:
- details = f"{x}"
- if r3 is not False:
- details += f" - {str(r3)}"
-
- error = {
- "type": "custom_validation_failed",
- "message": "Custom validation failed for node",
- "details": details,
- "extra_info": {
- "input_name": x,
- "input_config": info,
- "received_value": val,
- }
- }
- errors.append(error)
- continue
- else:
+ if x not in validate_function_inputs:
if isinstance(type_input, list):
if val not in type_input:
input_config = info
@@ -640,6 +630,35 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
errors.append(error)
continue
+ if len(validate_function_inputs) > 0:
+ input_data_all = get_input_data(inputs, obj_class, unique_id)
+ input_filtered = {}
+ for x in input_data_all:
+ if x in validate_function_inputs:
+ input_filtered[x] = input_data_all[x]
+
+ #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
+ ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
+ for x in input_filtered:
+ for i, r in enumerate(ret):
+ if r is not True:
+ details = f"{x}"
+ if r is not False:
+ details += f" - {str(r)}"
+
+ error = {
+ "type": "custom_validation_failed",
+ "message": "Custom validation failed for node",
+ "details": details,
+ "extra_info": {
+ "input_name": x,
+ "input_config": info,
+ "received_value": val,
+ }
+ }
+ errors.append(error)
+ continue
+
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
else:
@@ -771,7 +790,7 @@ class PromptQueue:
self.server.queue_updated()
self.not_empty.notify()
- def get(self, timeout=None) -> typing.Tuple[QueueTuple, int]:
+ def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait(timeout=timeout)
diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py
index c8ad20621..eed7577dd 100644
--- a/comfy/cmd/folder_paths.py
+++ b/comfy/cmd/folder_paths.py
@@ -188,8 +188,7 @@ def cached_filename_list_(folder_name):
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
- if time.perf_counter() < (out[2] + 0.5):
- return out
+
for x in out[1]:
time_modified = out[1][x]
folder = x
diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index bdf77a039..38abd96fe 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -23,9 +23,9 @@ def execute_prestartup_script():
return False
node_paths = folder_paths.get_folder_paths("custom_nodes")
+ node_prestartup_times = []
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path) if os.path.exists(custom_node_path) else []
- node_prestartup_times = []
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
@@ -69,6 +69,10 @@ if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device)
+if args.deterministic:
+ if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
+
from .. import utils
import yaml
@@ -78,12 +82,12 @@ from .server import BinaryEventTypes
from .. import model_management
-def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer):
+def prompt_worker(q, _server):
e = execution.PromptExecutor(_server)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
-
+ current_time = 0.0
while True:
timeout = None
if need_gc:
@@ -94,11 +98,13 @@ def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer)
item, item_id = queue_item
execution_start_time = time.perf_counter()
prompt_id = item[1]
+ _server.last_prompt_id = prompt_id
+
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
q.task_done(item_id, e.outputs_ui)
if _server.client_id is not None:
- _server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id)
+ _server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
@@ -119,7 +125,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def hijack_progress(server):
def hook(value, total, preview_image):
- server.send_sync("progress", {"value": value, "max": total}, server.client_id)
+ model_management.throw_exception_if_processing_interrupted()
+ progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
+
+ server.send_sync("progress", progress, server.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
@@ -204,7 +213,7 @@ def main():
print(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
- #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
+ # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index 78407adeb..06a621591 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -734,7 +734,8 @@ class PromptServer():
message = self.encode_bytes(event, data)
if sid is None:
- for ws in self.sockets.values():
+ sockets = list(self.sockets.values())
+ for ws in sockets:
await send_socket_catch_exception(ws.send_bytes, message)
elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
@@ -743,7 +744,8 @@ class PromptServer():
message = {"type": event, "data": data}
if sid is None:
- for ws in self.sockets.values():
+ sockets = list(self.sockets.values())
+ for ws in sockets:
await send_socket_catch_exception(ws.send_json, message)
elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_json, message)
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 919dfbea3..0b11ad1f7 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -1,10 +1,13 @@
import torch
import math
import os
+import contextlib
+
from . import utils
from . import model_management
from . import model_detection
from . import model_patcher
+from . import ops
from .cldm import cldm
from .t2i_adapter import adapter
@@ -34,13 +37,13 @@ class ControlBase:
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (0.0, 1.0)
+ self.global_average_pooling = False
self.timestep_range = None
if device is None:
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
- self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint
@@ -75,6 +78,7 @@ class ControlBase:
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
+ c.global_average_pooling = self.global_average_pooling
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@@ -127,12 +131,14 @@ class ControlBase:
return out
class ControlNet(ControlBase):
- def __init__(self, control_model, global_average_pooling=False, device=None):
+ def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device)
self.control_model = control_model
- self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
+ self.load_device = load_device
+ self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
+ self.manual_cast_dtype = manual_cast_dtype
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@@ -146,28 +152,31 @@ class ControlNet(ControlBase):
else:
return None
+ dtype = self.control_model.dtype
+ if self.manual_cast_dtype is not None:
+ dtype = self.manual_cast_dtype
+
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
- self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
+ self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
-
context = cond['c_crossattn']
y = cond.get('y', None)
if y is not None:
- y = y.to(self.control_model.dtype)
+ y = y.to(dtype)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
- control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):
- c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
+ c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
self.copy_to(c)
return c
@@ -198,10 +207,11 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
+ weight, bias = ops.cast_bias_weight(self, input)
if self.up is not None:
- return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
+ return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
- return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
+ return torch.nn.functional.linear(input, weight, bias)
class Conv2d(torch.nn.Module):
def __init__(
@@ -237,16 +247,11 @@ class ControlLoraOps:
def forward(self, input):
+ weight, bias = ops.cast_bias_weight(self, input)
if self.up is not None:
- return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
+ return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
- return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
-
- def conv_nd(self, dims, *args, **kwargs):
- if dims == 2:
- return self.Conv2d(*args, **kwargs)
- else:
- raise ValueError(f"unsupported dimensions: {dims}")
+ return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class ControlLora(ControlNet):
@@ -260,17 +265,26 @@ class ControlLora(ControlNet):
controlnet_config = model.model_config.unet_config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
- controlnet_config["operations"] = ControlLoraOps()
- self.control_model = cldm.ControlNet(**controlnet_config)
+ self.manual_cast_dtype = model.manual_cast_dtype
dtype = model.get_dtype()
- self.control_model.to(dtype)
+ if self.manual_cast_dtype is None:
+ class control_lora_ops(ControlLoraOps, ops.disable_weight_init):
+ pass
+ else:
+ class control_lora_ops(ControlLoraOps, ops.manual_cast):
+ pass
+ dtype = self.manual_cast_dtype
+
+ controlnet_config["operations"] = control_lora_ops
+ controlnet_config["dtype"] = dtype
+ self.control_model = cldm.ControlNet(**controlnet_config)
self.control_model.to(model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd:
- weight = model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
+ weight = sd[k]
try:
utils.set_attr(self.control_model, k, weight)
except:
@@ -367,6 +381,10 @@ def load_controlnet(ckpt_path, model=None):
if controlnet_config is None:
unet_dtype = model_management.unet_dtype()
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
+ load_device = model_management.get_torch_device()
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
+ if manual_cast_dtype is not None:
+ controlnet_config["operations"] = ops.manual_cast
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = cldm.ControlNet(**controlnet_config)
@@ -395,14 +413,12 @@ def load_controlnet(ckpt_path, model=None):
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)
- control_model = control_model.to(unet_dtype)
-
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
- control = ControlNet(control_model, global_average_pooling=global_average_pooling)
+ control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
class T2IAdapter(ControlBase):
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index c209087e0..2252a075e 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -33,3 +33,7 @@ class SDXL(LatentFormat):
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder"
+
+class SD_X4(LatentFormat):
+ def __init__(self):
+ self.scale_factor = 0.08333
diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py
index 0f9ac17cd..f85c3369a 100644
--- a/comfy/ldm/models/autoencoder.py
+++ b/comfy/ldm/models/autoencoder.py
@@ -9,6 +9,7 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..util import instantiate_from_config, get_obj_from_str
from ..modules.ema import LitEma
+from ... import ops
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True):
@@ -162,12 +163,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
},
**kwargs,
)
- self.quant_conv = torch.nn.Conv2d(
+ self.quant_conv = ops.disable_weight_init.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 97d03a589..17dda7c58 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -18,6 +18,7 @@ if model_management.xformers_enabled():
from ...cli_args import args
from ... import ops
+ops = ops.disable_weight_init
# CrossAttn precision handling
if args.dont_upcast_attention:
@@ -82,16 +83,6 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
-
-def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
@@ -112,19 +103,20 @@ def attention_basic(q, k, v, heads, mask=None):
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
- with torch.autocast(enabled=False, device_type = 'cuda'):
- q, k = q.float(), k.float()
- sim = einsum('b i d, b j d -> b i j', q, k) * scale
+ sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k
if exists(mask):
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
+ if mask.dtype == torch.bool:
+ mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+ else:
+ sim += mask
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
@@ -349,6 +341,18 @@ else:
if model_management.pytorch_attention_enabled():
optimized_attention_masked = attention_pytorch
+def optimized_attention_for_device(device, mask=False):
+ if device == torch.device("cpu"): #TODO
+ if model_management.pytorch_attention_enabled():
+ return attention_pytorch
+ else:
+ return attention_basic
+ if mask:
+ return optimized_attention_masked
+
+ return optimized_attention
+
+
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
@@ -393,7 +397,7 @@ class BasicTransformerBlock(nn.Module):
self.is_res = inner_dim == dim
if self.ff_in:
- self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
+ self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.disable_self_attn = disable_self_attn
@@ -413,10 +417,10 @@ class BasicTransformerBlock(nn.Module):
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
- self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+ self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
- self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
- self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+ self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
+ self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
@@ -558,7 +562,7 @@ class SpatialTransformer(nn.Module):
context_dim = [context_dim] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
- self.norm = Normalize(in_channels, dtype=dtype, device=device)
+ self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
if not use_linear:
self.proj_in = operations.Conv2d(in_channels,
inner_dim,
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 943a3dd1a..7d66ef4c3 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -8,6 +8,7 @@ from typing import Optional, Any
from .... import model_management
from .... import ops
+ops = ops.disable_weight_init
if model_management.xformers_enabled_vae():
import xformers
@@ -40,7 +41,7 @@ def nonlinearity(x):
def Normalize(in_channels, num_groups=32):
- return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 222435ba0..903dc2801 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -12,13 +12,13 @@ from .util import (
checkpoint,
avg_pool_nd,
zero_module,
- normalization,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from ...util import exists
from .... import ops
+ops = ops.disable_weight_init
class TimestepBlock(nn.Module):
"""
@@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
padding = kernel_size // 2
self.in_layers = nn.Sequential(
- nn.GroupNorm(32, channels, dtype=dtype, device=device),
+ operations.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
)
@@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
),
)
self.out_layers = nn.Sequential(
- nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
+ operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(),
nn.Dropout(p=dropout),
- zero_module(
- operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
- ),
+ operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
+ ,
)
if self.out_channels == channels:
@@ -438,9 +437,6 @@ class UNetModel(nn.Module):
operations=ops,
):
super().__init__()
- assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
- if use_spatial_transformer:
- assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
@@ -457,7 +453,6 @@ class UNetModel(nn.Module):
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
- self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
@@ -503,7 +498,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None:
if isinstance(self.num_classes, int):
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
@@ -810,13 +805,13 @@ class UNetModel(nn.Module):
self._feature_size += ch
self.out = nn.Sequential(
- nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+ operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(),
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
- nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+ operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
@@ -842,14 +837,14 @@ class UNetModel(nn.Module):
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
- h = x.type(self.dtype)
+ h = x
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py
index 7a86adeb4..9652cfc76 100644
--- a/comfy/ldm/modules/diffusionmodules/upscaling.py
+++ b/comfy/ldm/modules/diffusionmodules/upscaling.py
@@ -41,10 +41,14 @@ class AbstractLowScaleModel(nn.Module):
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
- def q_sample(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+ def q_sample(self, x_start, t, noise=None, seed=None):
+ if noise is None:
+ if seed is None:
+ noise = torch.randn_like(x_start)
+ else:
+ noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
+ return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
def forward(self, x):
return x, None
@@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level
- def forward(self, x, noise_level=None):
+ def forward(self, x, noise_level=None, seed=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
- z = self.q_sample(x, noise_level)
+ z = self.q_sample(x, noise_level, seed=seed)
return z, noise_level
diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py
index 85cf433c3..88346dc2e 100644
--- a/comfy/ldm/modules/diffusionmodules/util.py
+++ b/comfy/ldm/modules/diffusionmodules/util.py
@@ -16,7 +16,6 @@ import numpy as np
from einops import repeat, rearrange
from ...util import instantiate_from_config
-from .... import ops
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
@@ -52,9 +51,9 @@ class AlphaBlender(nn.Module):
if self.merge_strategy == "fixed":
# make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
- alpha = self.mix_factor
+ alpha = self.mix_factor.to(image_only_indicator.device)
elif self.merge_strategy == "learned":
- alpha = torch.sigmoid(self.mix_factor)
+ alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images":
@@ -62,7 +61,7 @@ class AlphaBlender(nn.Module):
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
- rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
+ rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible
@@ -273,46 +272,6 @@ def mean_flat(tensor):
return tensor.mean(dim=list(range(1, len(tensor.shape))))
-def normalization(channels, dtype=None):
- """
- Make a standard normalization layer.
- :param channels: number of input channels.
- :return: an nn.Module for normalization.
- """
- return GroupNorm32(32, channels, dtype=dtype)
-
-
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
- def forward(self, x):
- return x * torch.sigmoid(x)
-
-
-class GroupNorm32(nn.GroupNorm):
- def forward(self, x):
- return super().forward(x.float()).type(x.dtype)
-
-
-def conv_nd(dims, *args, **kwargs):
- """
- Create a 1D, 2D, or 3D convolution module.
- """
- if dims == 1:
- return nn.Conv1d(*args, **kwargs)
- elif dims == 2:
- return ops.Conv2d(*args, **kwargs)
- elif dims == 3:
- return nn.Conv3d(*args, **kwargs)
- raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def linear(*args, **kwargs):
- """
- Create a linear module.
- """
- return ops.Linear(*args, **kwargs)
-
-
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py
index b59bf204b..66767b587 100644
--- a/comfy/ldm/modules/encoders/noise_aug_modules.py
+++ b/comfy/ldm/modules/encoders/noise_aug_modules.py
@@ -15,12 +15,12 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def scale(self, x):
# re-normalize to centered mean and unit variance
- x = (x - self.data_mean) * 1. / self.data_std
+ x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
return x
def unscale(self, x):
# back to original data stats
- x = (x * self.data_std) + self.data_mean
+ x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x
def forward(self, x, noise_level=None):
diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py
index 04cb7eb00..c1966ea42 100644
--- a/comfy/ldm/modules/temporal_ae.py
+++ b/comfy/ldm/modules/temporal_ae.py
@@ -5,6 +5,7 @@ import torch
from einops import rearrange, repeat
from ... import ops
+ops = ops.disable_weight_init
from .diffusionmodules.model import (
AttnBlock,
@@ -81,14 +82,14 @@ class VideoResBlock(ResnetBlock):
x = self.time_stack(x, temb)
- alpha = self.get_alpha(bs=b // timesteps)
+ alpha = self.get_alpha(bs=b // timesteps).to(x.device)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
-class AE3DConv(torch.nn.Conv2d):
+class AE3DConv(ops.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
@@ -96,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
else:
padding = int(video_kernel_size // 2)
- self.time_mix_conv = torch.nn.Conv3d(
+ self.time_mix_conv = ops.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
@@ -166,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
emb = emb[:, None, :]
x_mix = x_mix + emb
- alpha = self.get_alpha()
+ alpha = self.get_alpha().to(x.device)
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
diff --git a/comfy/lora.py b/comfy/lora.py
index 78e5d4ce8..a19e7161d 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -43,7 +43,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
- patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
+ patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
@@ -64,7 +64,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
- patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
+ patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
@@ -116,8 +116,19 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
- patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
+ patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
+ #glora
+ a1_name = "{}.a1.weight".format(x)
+ a2_name = "{}.a2.weight".format(x)
+ b1_name = "{}.b1.weight".format(x)
+ b2_name = "{}.b2.weight".format(x)
+ if a1_name in lora:
+ patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+ loaded_keys.add(a1_name)
+ loaded_keys.add(a2_name)
+ loaded_keys.add(b1_name)
+ loaded_keys.add(b2_name)
w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
@@ -126,21 +137,21 @@ def load_lora(lora, to_load):
if w_norm is not None:
loaded_keys.add(w_norm_name)
- patch_dict[to_load[x]] = (w_norm,)
+ patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None:
loaded_keys.add(b_norm_name)
- patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
+ patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
- patch_dict[to_load[x]] = (diff_weight,)
+ patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
- patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
+ patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
for x in lora.keys():
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 38bdfe21b..7a0802e2b 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1,10 +1,12 @@
import torch
-from .ldm.modules.diffusionmodules.openaimodel import UNetModel
+from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
-from .ldm.modules.diffusionmodules.openaimodel import Timestep
+from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from . import model_management
from . import conds
+from . import ops
from enum import Enum
+import contextlib
from . import utils
class ModelType(Enum):
@@ -13,7 +15,7 @@ class ModelType(Enum):
V_PREDICTION_EDM = 3
-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
+from .model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
def model_sampling(model_config, model_type):
@@ -40,9 +42,14 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
+ self.manual_cast_dtype = model_config.manual_cast_dtype
if not unet_config.get("disable_unet_model_creation", False):
- self.diffusion_model = UNetModel(**unet_config, device=device)
+ if self.manual_cast_dtype is not None:
+ operations = ops.manual_cast
+ else:
+ operations = ops.disable_weight_init
+ self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@@ -61,15 +68,21 @@ class BaseModel(torch.nn.Module):
context = c_crossattn
dtype = self.get_dtype()
+
+ if self.manual_cast_dtype is not None:
+ dtype = self.manual_cast_dtype
+
xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float()
context = context.to(dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
- if hasattr(extra, "to"):
- extra = extra.to(dtype)
+ if hasattr(extra, "dtype"):
+ if extra.dtype != torch.int and extra.dtype != torch.long:
+ extra = extra.to(dtype)
extra_conds[o] = extra
+
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
@@ -117,6 +130,10 @@ class BaseModel(torch.nn.Module):
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = conds.CONDRegular(adm)
+
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
return out
def load_model_weights(self, sd, unet_prefix=""):
@@ -144,11 +161,7 @@ class BaseModel(torch.nn.Module):
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
- unet_sd = self.diffusion_model.state_dict()
- unet_state_dict = {}
- for k in unet_sd:
- unet_state_dict[k] = model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
-
+ unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:
@@ -165,9 +178,12 @@ class BaseModel(torch.nn.Module):
def memory_required(self, input_shape):
if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
+ dtype = self.get_dtype()
+ if self.manual_cast_dtype is not None:
+ dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
- return (area * model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
+ return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3]
@@ -307,9 +323,75 @@ class SVD_img2vid(BaseModel):
out['c_concat'] = conds.CONDNoiseShape(latent_image)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
+
if "time_conditioning" in kwargs:
out["time_context"] = conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
return out
+
+class Stable_Zero123(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
+ super().__init__(model_config, model_type, device=device)
+ self.cc_projection = ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
+ self.cc_projection.weight.copy_(cc_projection_weight)
+ self.cc_projection.bias.copy_(cc_projection_bias)
+
+ def extra_conds(self, **kwargs):
+ out = {}
+
+ latent_image = kwargs.get("concat_latent_image", None)
+ noise = kwargs.get("noise", None)
+
+ if latent_image is None:
+ latent_image = torch.zeros_like(noise)
+
+ if latent_image.shape[1:] != noise.shape[1:]:
+ latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+ latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
+
+ out['c_concat'] = conds.CONDNoiseShape(latent_image)
+
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ if cross_attn.shape[-1] != 768:
+ cross_attn = self.cc_projection(cross_attn)
+ out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
+ return out
+
+class SD_X4Upscaler(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
+ super().__init__(model_config, model_type, device=device)
+ self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)
+
+ def extra_conds(self, **kwargs):
+ out = {}
+
+ image = kwargs.get("concat_image", None)
+ noise = kwargs.get("noise", None)
+ noise_augment = kwargs.get("noise_augmentation", 0.0)
+ device = kwargs["device"]
+ seed = kwargs["seed"] - 10
+
+ noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
+
+ if image is None:
+ image = torch.zeros_like(noise)[:,:3]
+
+ if image.shape[1:] != noise.shape[1:]:
+ image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+ noise_level = torch.tensor([noise_level], device=device)
+ if noise_augment > 0:
+ image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)
+
+ image = utils.resize_to_batch_size(image, noise.shape[0])
+
+ out['c_concat'] = conds.CONDNoiseShape(image)
+ out['y'] = conds.CONDRegular(noise_level)
+ return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index d08941bb9..13ac3bdc4 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config = {
"use_checkpoint": False,
"image_size": 32,
- "out_channels": 4,
"use_spatial_transformer": True,
"legacy": False
}
@@ -50,6 +49,12 @@ def detect_unet_config(state_dict, key_prefix, dtype):
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
+ out_key = '{}out.2.weight'.format(key_prefix)
+ if out_key in state_dict:
+ out_channels = state_dict[out_key].shape[0]
+ else:
+ out_channels = 4
+
num_res_blocks = []
channel_mult = []
attention_resolutions = []
@@ -122,6 +127,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
transformer_depth_middle = -1
unet_config["in_channels"] = in_channels
+ unet_config["out_channels"] = out_channels
unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks
unet_config["transformer_depth"] = transformer_depth
@@ -289,7 +295,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
- supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
+ Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+ 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
+ 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
+ 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
+ 'use_temporal_attention': False, 'use_temporal_resblock': False}
+
+ supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega]
for unet_config in supported_models:
matches = True
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 068e9d162..6d6800e86 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -2,6 +2,7 @@ import psutil
from enum import Enum
from .cli_args import args
from . import utils
+
import torch
import sys
@@ -28,6 +29,10 @@ total_vram = 0
lowvram_available = True
xpu_available = False
+if args.deterministic:
+ print("Using deterministic algorithms for pytorch")
+ torch.use_deterministic_algorithms(True, warn_only=True)
+
directml_enabled = False
if args.directml is not None:
import torch_directml
@@ -182,6 +187,9 @@ except:
if is_intel_xpu():
VAE_DTYPE = torch.bfloat16
+if args.cpu_vae:
+ VAE_DTYPE = torch.float32
+
if args.fp16_vae:
VAE_DTYPE = torch.float16
elif args.bf16_vae:
@@ -214,15 +222,8 @@ if args.force_fp16 or cpu_state == CPUState.MPS:
FORCE_FP16 = True
if lowvram_available:
- try:
- import accelerate
- if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
- vram_state = set_vram_to
- except Exception as e:
- import traceback
- print(traceback.format_exc())
- print("ERROR: LOW VRAM MODE NEEDS accelerate.")
- lowvram_available = False
+ if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+ vram_state = set_vram_to
if cpu_state != CPUState.GPU:
@@ -262,6 +263,14 @@ print("VAE dtype:", VAE_DTYPE)
current_loaded_models = []
+def module_size(module):
+ module_mem = 0
+ sd = module.state_dict()
+ for k in sd:
+ t = sd[k]
+ module_mem += t.nelement() * t.element_size()
+ return module_mem
+
class LoadedModel:
def __init__(self, model):
self.model = model
@@ -294,8 +303,20 @@ class LoadedModel:
if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
- device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
- accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
+ mem_counter = 0
+ for m in self.real_model.modules():
+ if hasattr(m, "comfy_cast_weights"):
+ m.prev_comfy_cast_weights = m.comfy_cast_weights
+ m.comfy_cast_weights = True
+ module_mem = module_size(m)
+ if mem_counter + module_mem < lowvram_model_memory:
+ m.to(self.device)
+ mem_counter += module_mem
+ elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
+ m.to(self.device)
+ mem_counter += module_size(m)
+ print("lowvram: loaded module regularly", m)
+
self.model_accelerated = True
if is_intel_xpu() and not args.disable_ipex_optimize:
@@ -305,7 +326,11 @@ class LoadedModel:
def model_unload(self):
if self.model_accelerated:
- accelerate.hooks.remove_hook_from_submodules(self.real_model)
+ for m in self.real_model.modules():
+ if hasattr(m, "prev_comfy_cast_weights"):
+ m.comfy_cast_weights = m.prev_comfy_cast_weights
+ del m.prev_comfy_cast_weights
+
self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device)
@@ -398,14 +423,14 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
- lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+ lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
else:
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
- lowvram_model_memory = 256 * 1024 * 1024
+ lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model)
@@ -430,6 +455,13 @@ def dtype_size(dtype):
dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2
+ elif dtype == torch.float32:
+ dtype_size = 4
+ else:
+ try:
+ dtype_size = dtype.itemsize
+ except: #Old pytorch doesn't have .itemsize
+ pass
return dtype_size
def unet_offload_device():
@@ -459,10 +491,30 @@ def unet_inital_load_device(parameters, dtype):
def unet_dtype(device=None, model_params=0):
if args.bf16_unet:
return torch.bfloat16
+ if args.fp16_unet:
+ return torch.float16
+ if args.fp8_e4m3fn_unet:
+ return torch.float8_e4m3fn
+ if args.fp8_e5m2_unet:
+ return torch.float8_e5m2
if should_use_fp16(device=device, model_params=model_params):
return torch.float16
return torch.float32
+# None means no manual cast
+def unet_manual_cast(weight_dtype, inference_device):
+ if weight_dtype == torch.float32:
+ return None
+
+ fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
+ if fp16_supported and weight_dtype == torch.float16:
+ return None
+
+ if fp16_supported:
+ return torch.float16
+ else:
+ return torch.float32
+
def text_encoder_offload_device():
if args.gpu_only:
return get_torch_device()
@@ -492,12 +544,23 @@ def text_encoder_dtype(device=None):
elif args.fp32_text_enc:
return torch.float32
+ if is_device_cpu(device):
+ return torch.float16
+
if should_use_fp16(device, prioritize_performance=False):
return torch.float16
else:
return torch.float32
+def intermediate_device():
+ if args.gpu_only:
+ return get_torch_device()
+ else:
+ return torch.device("cpu")
+
def vae_device():
+ if args.cpu_vae:
+ return torch.device("cpu")
return get_torch_device()
def vae_offload_device():
@@ -515,6 +578,22 @@ def get_autocast_device(dev):
return dev.type
return "cuda"
+def supports_dtype(device, dtype): #TODO
+ if dtype == torch.float32:
+ return True
+ if is_device_cpu(device):
+ return False
+ if dtype == torch.float16:
+ return True
+ if dtype == torch.bfloat16:
+ return True
+ return False
+
+def device_supports_non_blocking(device):
+ if is_device_mps(device):
+ return False #pytorch bug? mps doesn't support non blocking
+ return True
+
def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@@ -525,15 +604,17 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu():
device_supports_cast = True
+ non_blocking = device_supports_non_blocking(device)
+
if device_supports_cast:
if copy:
if tensor.device == device:
- return tensor.to(dtype, copy=copy)
- return tensor.to(device, copy=copy).to(dtype)
+ return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
+ return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
- return tensor.to(device).to(dtype)
+ return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
- return tensor.to(dtype).to(device, copy=copy)
+ return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def xformers_enabled():
global directml_enabled
@@ -687,11 +768,11 @@ def soft_empty_cache(force=False):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
-def resolve_lowvram_weight(weight, model, key):
- if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
- key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
- op = utils.get_attr(model, '.'.join(key_split[:-1]))
- weight = op._hf_hook.weights_map[key_split[-1]]
+def unload_all_models():
+ free_memory(1e30, get_torch_device())
+
+
+def resolve_lowvram_weight(weight, model, key): #TODO: remove
return weight
#TODO: might be cleaner to put this somewhere else
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index c8fa1b358..814537171 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -28,13 +28,9 @@ class ModelPatcher:
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
- size = 0
- for k in model_sd:
- t = model_sd[k]
- size += t.nelement() * t.element_size()
- self.size = size
+ self.size = model_management.module_size(self.model)
self.model_keys = set(model_sd.keys())
- return size
+ return self.size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
@@ -55,11 +51,18 @@ class ModelPatcher:
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
- def set_model_sampler_cfg_function(self, sampler_cfg_function):
+ def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
+ if disable_cfg1_optimization:
+ self.model_options["disable_cfg1_optimization"] = True
+
+ def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
+ self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
+ if disable_cfg1_optimization:
+ self.model_options["disable_cfg1_optimization"] = True
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
@@ -70,13 +73,17 @@ class ModelPatcher:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
- def set_model_patch_replace(self, patch, name, block_name, number):
+ def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
- to["patches_replace"][name][(block_name, number)] = patch
+ if transformer_index is not None:
+ block = (block_name, number, transformer_index)
+ else:
+ block = (block_name, number)
+ to["patches_replace"][name][block] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
@@ -84,11 +91,11 @@ class ModelPatcher:
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
- def set_model_attn1_replace(self, patch, block_name, number):
- self.set_model_patch_replace(patch, "attn1", block_name, number)
+ def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
+ self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
- def set_model_attn2_replace(self, patch, block_name, number):
- self.set_model_patch_replace(patch, "attn2", block_name, number)
+ def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
+ self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
@@ -167,40 +174,41 @@ class ModelPatcher:
sd.pop(k)
return sd
- def patch_model(self, device_to=None):
+ def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = getattr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])
- model_sd = self.model_state_dict()
- for key in self.patches:
- if key not in model_sd:
- print("could not patch. key doesn't exist in model:", key)
- continue
+ if patch_weights:
+ model_sd = self.model_state_dict()
+ for key in self.patches:
+ if key not in model_sd:
+ print("could not patch. key doesn't exist in model:", key)
+ continue
- weight = model_sd[key]
+ weight = model_sd[key]
- inplace_update = self.weight_inplace_update
+ inplace_update = self.weight_inplace_update
- if key not in self.backup:
- self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+ if key not in self.backup:
+ self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+
+ if device_to is not None:
+ temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+ else:
+ temp_weight = weight.to(torch.float32, copy=True)
+ out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+ if inplace_update:
+ utils.copy_to_param(self.model, key, out_weight)
+ else:
+ utils.set_attr(self.model, key, out_weight)
+ del temp_weight
if device_to is not None:
- temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
- else:
- temp_weight = weight.to(torch.float32, copy=True)
- out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
- if inplace_update:
- utils.copy_to_param(self.model, key, out_weight)
- else:
- utils.set_attr(self.model, key, out_weight)
- del temp_weight
-
- if device_to is not None:
- self.model.to(device_to)
- self.current_device = device_to
+ self.model.to(device_to)
+ self.current_device = device_to
return self.model
@@ -217,13 +225,19 @@ class ModelPatcher:
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
+ patch_type = "diff"
+ elif len(v) == 2:
+ patch_type = v[0]
+ v = v[1]
+
+ if patch_type == "diff":
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
- elif len(v) == 4: #lora/locon
+ elif patch_type == "lora": #lora/locon
mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
if v[2] is not None:
@@ -237,7 +251,7 @@ class ModelPatcher:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
- elif len(v) == 8: #lokr
+ elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
@@ -276,7 +290,7 @@ class ModelPatcher:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
- else: #loha
+ elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
@@ -305,6 +319,18 @@ class ModelPatcher:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
+ elif patch_type == "glora":
+ if v[4] is not None:
+ alpha *= v[4] / v[0].shape[0]
+
+ a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
+ a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
+ b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
+ b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
+
+ weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+ else:
+ print("patch type not recognized", patch_type, key)
return weight
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index 69c8b1f01..cc8745c10 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -22,10 +22,17 @@ class V_PREDICTION(EPS):
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
- beta_schedule = "linear"
+
if model_config is not None:
- beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
- self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+ sampling_settings = model_config.sampling_settings
+ else:
+ sampling_settings = {}
+
+ beta_schedule = sampling_settings.get("beta_schedule", "linear")
+ linear_start = sampling_settings.get("linear_start", 0.00085)
+ linear_end = sampling_settings.get("linear_end", 0.012)
+
+ self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index f2665565b..068cd6c95 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -6,7 +6,7 @@ import hashlib
import math
import random
-from PIL import Image, ImageOps
+from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
@@ -930,8 +930,8 @@ class GLIGENTextBoxApply:
return (c, )
class EmptyLatentImage:
- def __init__(self, device="cpu"):
- self.device = device
+ def __init__(self):
+ self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
@@ -944,7 +944,7 @@ class EmptyLatentImage:
CATEGORY = "latent"
def generate(self, width, height, batch_size=1):
- latent = torch.zeros([batch_size, 4, height // 8, width // 8])
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
@@ -1395,17 +1395,30 @@ class LoadImage:
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
- i = Image.open(image_path)
- i = ImageOps.exif_transpose(i)
- image = i.convert("RGB")
- image = np.array(image).astype(np.float32) / 255.0
- image = torch.from_numpy(image)[None,]
- if 'A' in i.getbands():
- mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
- mask = 1. - torch.from_numpy(mask)
+ img = Image.open(image_path)
+ output_images = []
+ output_masks = []
+ for i in ImageSequence.Iterator(img):
+ i = ImageOps.exif_transpose(i)
+ image = i.convert("RGB")
+ image = np.array(image).astype(np.float32) / 255.0
+ image = torch.from_numpy(image)[None,]
+ if 'A' in i.getbands():
+ mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
+ mask = 1. - torch.from_numpy(mask)
+ else:
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
+ output_images.append(image)
+ output_masks.append(mask.unsqueeze(0))
+
+ if len(output_images) > 1:
+ output_image = torch.cat(output_images, dim=0)
+ output_mask = torch.cat(output_masks, dim=0)
else:
- mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
- return (image, mask.unsqueeze(0))
+ output_image = output_images[0]
+ output_mask = output_masks[0]
+
+ return (output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):
@@ -1463,13 +1476,10 @@ class LoadImageMask:
return m.digest().hex()
@classmethod
- def VALIDATE_INPUTS(s, image, channel):
+ def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
- if channel not in s._color_channels:
- return "Invalid color channel: {}".format(channel)
-
return True
class ImageScale:
diff --git a/comfy/ops.py b/comfy/ops.py
index 0bfb698aa..f6f85de60 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -1,40 +1,115 @@
import torch
from contextlib import contextmanager
+import comfy.model_management
-class Linear(torch.nn.Linear):
- def reset_parameters(self):
- return None
+def cast_bias_weight(s, input):
+ bias = None
+ non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
+ if s.bias is not None:
+ bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+ weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+ return weight, bias
-class Conv2d(torch.nn.Conv2d):
- def reset_parameters(self):
- return None
-class Conv3d(torch.nn.Conv3d):
- def reset_parameters(self):
- return None
+class disable_weight_init:
+ class Linear(torch.nn.Linear):
+ comfy_cast_weights = False
+ def reset_parameters(self):
+ return None
-def conv_nd(dims, *args, **kwargs):
- if dims == 2:
- return Conv2d(*args, **kwargs)
- elif dims == 3:
- return Conv3d(*args, **kwargs)
- else:
- raise ValueError(f"unsupported dimensions: {dims}")
+ def forward_comfy_cast_weights(self, input):
+ weight, bias = cast_bias_weight(self, input)
+ return torch.nn.functional.linear(input, weight, bias)
-@contextmanager
-def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
- old_torch_nn_linear = torch.nn.Linear
- force_device = device
- force_dtype = dtype
- def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
- if force_device is not None:
- device = force_device
- if force_dtype is not None:
- dtype = force_dtype
- return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
+ def forward(self, *args, **kwargs):
+ if self.comfy_cast_weights:
+ return self.forward_comfy_cast_weights(*args, **kwargs)
+ else:
+ return super().forward(*args, **kwargs)
- torch.nn.Linear = linear_with_dtype
- try:
- yield
- finally:
- torch.nn.Linear = old_torch_nn_linear
+ class Conv2d(torch.nn.Conv2d):
+ comfy_cast_weights = False
+ def reset_parameters(self):
+ return None
+
+ def forward_comfy_cast_weights(self, input):
+ weight, bias = cast_bias_weight(self, input)
+ return self._conv_forward(input, weight, bias)
+
+ def forward(self, *args, **kwargs):
+ if self.comfy_cast_weights:
+ return self.forward_comfy_cast_weights(*args, **kwargs)
+ else:
+ return super().forward(*args, **kwargs)
+
+ class Conv3d(torch.nn.Conv3d):
+ comfy_cast_weights = False
+ def reset_parameters(self):
+ return None
+
+ def forward_comfy_cast_weights(self, input):
+ weight, bias = cast_bias_weight(self, input)
+ return self._conv_forward(input, weight, bias)
+
+ def forward(self, *args, **kwargs):
+ if self.comfy_cast_weights:
+ return self.forward_comfy_cast_weights(*args, **kwargs)
+ else:
+ return super().forward(*args, **kwargs)
+
+ class GroupNorm(torch.nn.GroupNorm):
+ comfy_cast_weights = False
+ def reset_parameters(self):
+ return None
+
+ def forward_comfy_cast_weights(self, input):
+ weight, bias = cast_bias_weight(self, input)
+ return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+ def forward(self, *args, **kwargs):
+ if self.comfy_cast_weights:
+ return self.forward_comfy_cast_weights(*args, **kwargs)
+ else:
+ return super().forward(*args, **kwargs)
+
+
+ class LayerNorm(torch.nn.LayerNorm):
+ comfy_cast_weights = False
+ def reset_parameters(self):
+ return None
+
+ def forward_comfy_cast_weights(self, input):
+ weight, bias = cast_bias_weight(self, input)
+ return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
+ def forward(self, *args, **kwargs):
+ if self.comfy_cast_weights:
+ return self.forward_comfy_cast_weights(*args, **kwargs)
+ else:
+ return super().forward(*args, **kwargs)
+
+ @classmethod
+ def conv_nd(s, dims, *args, **kwargs):
+ if dims == 2:
+ return s.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return s.Conv3d(*args, **kwargs)
+ else:
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class manual_cast(disable_weight_init):
+ class Linear(disable_weight_init.Linear):
+ comfy_cast_weights = True
+
+ class Conv2d(disable_weight_init.Conv2d):
+ comfy_cast_weights = True
+
+ class Conv3d(disable_weight_init.Conv3d):
+ comfy_cast_weights = True
+
+ class GroupNorm(disable_weight_init.GroupNorm):
+ comfy_cast_weights = True
+
+ class LayerNorm(disable_weight_init.LayerNorm):
+ comfy_cast_weights = True
diff --git a/comfy/package_data_path_helper.py b/comfy/package_data_path_helper.py
new file mode 100644
index 000000000..d4dea80a4
--- /dev/null
+++ b/comfy/package_data_path_helper.py
@@ -0,0 +1,9 @@
+from importlib.resources import path
+import os
+
+
+def get_editable_resource_path(caller_file, *package_path):
+ filename = os.path.join(os.path.dirname(os.path.realpath(caller_file)), package_path[-1])
+ if not os.path.exists(filename):
+ filename = path(*package_path)
+ return filename
diff --git a/comfy/sample.py b/comfy/sample.py
index 197f2eb30..565ea662d 100644
--- a/comfy/sample.py
+++ b/comfy/sample.py
@@ -47,7 +47,8 @@ def convert_cond(cond):
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
- model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0])
+ model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
+ temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
out.append(temp)
return out
@@ -98,10 +99,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
- samples = samples.cpu()
+ samples = samples.to(model_management.intermediate_device())
cleanup_additional_models(models)
- cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
+ cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
@@ -111,8 +112,8 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
sigmas = sigmas.to(model.load_device)
samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
- samples = samples.cpu()
+ samples = samples.to(model_management.intermediate_device())
cleanup_additional_models(models)
- cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
+ cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 6ddadd893..eeac3fbd5 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -1,256 +1,264 @@
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
import torch
+import collections
from . import model_management
import math
+def get_area_and_mult(conds, x_in, timestep_in):
+ area = (x_in.shape[2], x_in.shape[3], 0, 0)
+ strength = 1.0
+
+ if 'timestep_start' in conds:
+ timestep_start = conds['timestep_start']
+ if timestep_in[0] > timestep_start:
+ return None
+ if 'timestep_end' in conds:
+ timestep_end = conds['timestep_end']
+ if timestep_in[0] < timestep_end:
+ return None
+ if 'area' in conds:
+ area = conds['area']
+ if 'strength' in conds:
+ strength = conds['strength']
+
+ input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+ if 'mask' in conds:
+ # Scale the mask to the size of the input
+ # The mask should have been resized as we began the sampling process
+ mask_strength = 1.0
+ if "mask_strength" in conds:
+ mask_strength = conds["mask_strength"]
+ mask = conds['mask']
+ assert(mask.shape[1] == x_in.shape[2])
+ assert(mask.shape[2] == x_in.shape[3])
+ mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
+ mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
+ else:
+ mask = torch.ones_like(input_x)
+ mult = mask * strength
+
+ if 'mask' not in conds:
+ rr = 8
+ if area[2] != 0:
+ for t in range(rr):
+ mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
+ if (area[0] + area[2]) < x_in.shape[2]:
+ for t in range(rr):
+ mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
+ if area[3] != 0:
+ for t in range(rr):
+ mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
+ if (area[1] + area[3]) < x_in.shape[3]:
+ for t in range(rr):
+ mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
+
+ conditioning = {}
+ model_conds = conds["model_conds"]
+ for c in model_conds:
+ conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
+
+ control = conds.get('control', None)
+
+ patches = None
+ if 'gligen' in conds:
+ gligen = conds['gligen']
+ patches = {}
+ gligen_type = gligen[0]
+ gligen_model = gligen[1]
+ if gligen_type == "position":
+ gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
+ else:
+ gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
+
+ patches['middle_patch'] = [gligen_patch]
+
+ cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
+ return cond_obj(input_x, mult, conditioning, area, control, patches)
+
+def cond_equal_size(c1, c2):
+ if c1 is c2:
+ return True
+ if c1.keys() != c2.keys():
+ return False
+ for k in c1:
+ if not c1[k].can_concat(c2[k]):
+ return False
+ return True
+
+def can_concat_cond(c1, c2):
+ if c1.input_x.shape != c2.input_x.shape:
+ return False
+
+ def objects_concatable(obj1, obj2):
+ if (obj1 is None) != (obj2 is None):
+ return False
+ if obj1 is not None:
+ if obj1 is not obj2:
+ return False
+ return True
+
+ if not objects_concatable(c1.control, c2.control):
+ return False
+
+ if not objects_concatable(c1.patches, c2.patches):
+ return False
+
+ return cond_equal_size(c1.conditioning, c2.conditioning)
+
+def cond_cat(c_list):
+ c_crossattn = []
+ c_concat = []
+ c_adm = []
+ crossattn_max_len = 0
+
+ temp = {}
+ for x in c_list:
+ for k in x:
+ cur = temp.get(k, [])
+ cur.append(x[k])
+ temp[k] = cur
+
+ out = {}
+ for k in temp:
+ conds = temp[k]
+ out[k] = conds[0].concat(conds[1:])
+
+ return out
+
+def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
+ out_cond = torch.zeros_like(x_in)
+ out_count = torch.ones_like(x_in) * 1e-37
+
+ out_uncond = torch.zeros_like(x_in)
+ out_uncond_count = torch.ones_like(x_in) * 1e-37
+
+ COND = 0
+ UNCOND = 1
+
+ to_run = []
+ for x in cond:
+ p = get_area_and_mult(x, x_in, timestep)
+ if p is None:
+ continue
+
+ to_run += [(p, COND)]
+ if uncond is not None:
+ for x in uncond:
+ p = get_area_and_mult(x, x_in, timestep)
+ if p is None:
+ continue
+
+ to_run += [(p, UNCOND)]
+
+ while len(to_run) > 0:
+ first = to_run[0]
+ first_shape = first[0][0].shape
+ to_batch_temp = []
+ for x in range(len(to_run)):
+ if can_concat_cond(to_run[x][0], first[0]):
+ to_batch_temp += [x]
+
+ to_batch_temp.reverse()
+ to_batch = to_batch_temp[:1]
+
+ free_memory = model_management.get_free_memory(x_in.device)
+ for i in range(1, len(to_batch_temp) + 1):
+ batch_amount = to_batch_temp[:len(to_batch_temp)//i]
+ input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
+ if model.memory_required(input_shape) < free_memory:
+ to_batch = batch_amount
+ break
+
+ input_x = []
+ mult = []
+ c = []
+ cond_or_uncond = []
+ area = []
+ control = None
+ patches = None
+ for x in to_batch:
+ o = to_run.pop(x)
+ p = o[0]
+ input_x.append(p.input_x)
+ mult.append(p.mult)
+ c.append(p.conditioning)
+ area.append(p.area)
+ cond_or_uncond.append(o[1])
+ control = p.control
+ patches = p.patches
+
+ batch_chunks = len(cond_or_uncond)
+ input_x = torch.cat(input_x)
+ c = cond_cat(c)
+ timestep_ = torch.cat([timestep] * batch_chunks)
+
+ if control is not None:
+ c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
+
+ transformer_options = {}
+ if 'transformer_options' in model_options:
+ transformer_options = model_options['transformer_options'].copy()
+
+ if patches is not None:
+ if "patches" in transformer_options:
+ cur_patches = transformer_options["patches"].copy()
+ for p in patches:
+ if p in cur_patches:
+ cur_patches[p] = cur_patches[p] + patches[p]
+ else:
+ cur_patches[p] = patches[p]
+ else:
+ transformer_options["patches"] = patches
+
+ transformer_options["cond_or_uncond"] = cond_or_uncond[:]
+ transformer_options["sigmas"] = timestep
+
+ c['transformer_options'] = transformer_options
+
+ if 'model_function_wrapper' in model_options:
+ output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
+ else:
+ output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
+ del input_x
+
+ for o in range(batch_chunks):
+ if cond_or_uncond[o] == COND:
+ out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
+ out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
+ else:
+ out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
+ out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
+ del mult
+
+ out_cond /= out_count
+ del out_count
+ out_uncond /= out_uncond_count
+ del out_uncond_count
+ return out_cond, out_uncond
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
- def get_area_and_mult(conds, x_in, timestep_in):
- area = (x_in.shape[2], x_in.shape[3], 0, 0)
- strength = 1.0
-
- if 'timestep_start' in conds:
- timestep_start = conds['timestep_start']
- if timestep_in[0] > timestep_start:
- return None
- if 'timestep_end' in conds:
- timestep_end = conds['timestep_end']
- if timestep_in[0] < timestep_end:
- return None
- if 'area' in conds:
- area = conds['area']
- if 'strength' in conds:
- strength = conds['strength']
-
- input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
- if 'mask' in conds:
- # Scale the mask to the size of the input
- # The mask should have been resized as we began the sampling process
- mask_strength = 1.0
- if "mask_strength" in conds:
- mask_strength = conds["mask_strength"]
- mask = conds['mask']
- assert(mask.shape[1] == x_in.shape[2])
- assert(mask.shape[2] == x_in.shape[3])
- mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
- mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
- else:
- mask = torch.ones_like(input_x)
- mult = mask * strength
-
- if 'mask' not in conds:
- rr = 8
- if area[2] != 0:
- for t in range(rr):
- mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
- if (area[0] + area[2]) < x_in.shape[2]:
- for t in range(rr):
- mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
- if area[3] != 0:
- for t in range(rr):
- mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
- if (area[1] + area[3]) < x_in.shape[3]:
- for t in range(rr):
- mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
-
- conditionning = {}
- model_conds = conds["model_conds"]
- for c in model_conds:
- conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
-
- control = None
- if 'control' in conds:
- control = conds['control']
-
- patches = None
- if 'gligen' in conds:
- gligen = conds['gligen']
- patches = {}
- gligen_type = gligen[0]
- gligen_model = gligen[1]
- if gligen_type == "position":
- gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
- else:
- gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
-
- patches['middle_patch'] = [gligen_patch]
-
- return (input_x, mult, conditionning, area, control, patches)
-
- def cond_equal_size(c1, c2):
- if c1 is c2:
- return True
- if c1.keys() != c2.keys():
- return False
- for k in c1:
- if not c1[k].can_concat(c2[k]):
- return False
- return True
-
- def can_concat_cond(c1, c2):
- if c1[0].shape != c2[0].shape:
- return False
-
- #control
- if (c1[4] is None) != (c2[4] is None):
- return False
- if c1[4] is not None:
- if c1[4] is not c2[4]:
- return False
-
- #patches
- if (c1[5] is None) != (c2[5] is None):
- return False
- if (c1[5] is not None):
- if c1[5] is not c2[5]:
- return False
-
- return cond_equal_size(c1[2], c2[2])
-
- def cond_cat(c_list):
- c_crossattn = []
- c_concat = []
- c_adm = []
- crossattn_max_len = 0
-
- temp = {}
- for x in c_list:
- for k in x:
- cur = temp.get(k, [])
- cur.append(x[k])
- temp[k] = cur
-
- out = {}
- for k in temp:
- conds = temp[k]
- out[k] = conds[0].concat(conds[1:])
-
- return out
-
- def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
- out_cond = torch.zeros_like(x_in)
- out_count = torch.ones_like(x_in) * 1e-37
-
- out_uncond = torch.zeros_like(x_in)
- out_uncond_count = torch.ones_like(x_in) * 1e-37
-
- COND = 0
- UNCOND = 1
-
- to_run = []
- for x in cond:
- p = get_area_and_mult(x, x_in, timestep)
- if p is None:
- continue
-
- to_run += [(p, COND)]
- if uncond is not None:
- for x in uncond:
- p = get_area_and_mult(x, x_in, timestep)
- if p is None:
- continue
-
- to_run += [(p, UNCOND)]
-
- while len(to_run) > 0:
- first = to_run[0]
- first_shape = first[0][0].shape
- to_batch_temp = []
- for x in range(len(to_run)):
- if can_concat_cond(to_run[x][0], first[0]):
- to_batch_temp += [x]
-
- to_batch_temp.reverse()
- to_batch = to_batch_temp[:1]
-
- free_memory = model_management.get_free_memory(x_in.device)
- for i in range(1, len(to_batch_temp) + 1):
- batch_amount = to_batch_temp[:len(to_batch_temp)//i]
- input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
- if model.memory_required(input_shape) < free_memory:
- to_batch = batch_amount
- break
-
- input_x = []
- mult = []
- c = []
- cond_or_uncond = []
- area = []
- control = None
- patches = None
- for x in to_batch:
- o = to_run.pop(x)
- p = o[0]
- input_x += [p[0]]
- mult += [p[1]]
- c += [p[2]]
- area += [p[3]]
- cond_or_uncond += [o[1]]
- control = p[4]
- patches = p[5]
-
- batch_chunks = len(cond_or_uncond)
- input_x = torch.cat(input_x)
- c = cond_cat(c)
- timestep_ = torch.cat([timestep] * batch_chunks)
-
- if control is not None:
- c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
-
- transformer_options = {}
- if 'transformer_options' in model_options:
- transformer_options = model_options['transformer_options'].copy()
-
- if patches is not None:
- if "patches" in transformer_options:
- cur_patches = transformer_options["patches"].copy()
- for p in patches:
- if p in cur_patches:
- cur_patches[p] = cur_patches[p] + patches[p]
- else:
- cur_patches[p] = patches[p]
- else:
- transformer_options["patches"] = patches
-
- transformer_options["cond_or_uncond"] = cond_or_uncond[:]
- transformer_options["sigmas"] = timestep
-
- c['transformer_options'] = transformer_options
-
- if 'model_function_wrapper' in model_options:
- output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
- else:
- output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
- del input_x
-
- for o in range(batch_chunks):
- if cond_or_uncond[o] == COND:
- out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
- out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
- else:
- out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
- out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
- del mult
-
- out_cond /= out_count
- del out_count
- out_uncond /= out_uncond_count
- del out_uncond_count
- return out_cond, out_uncond
-
-
- if math.isclose(cond_scale, 1.0):
- uncond = None
-
- cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
- if "sampler_cfg_function" in model_options:
- args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
- return x - model_options["sampler_cfg_function"](args)
+ if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
+ uncond_ = None
else:
- return uncond + (cond - uncond) * cond_scale
+ uncond_ = uncond
+
+ cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
+ if "sampler_cfg_function" in model_options:
+ args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
+ "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
+ cfg_result = x - model_options["sampler_cfg_function"](args)
+ else:
+ cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
+
+ for fn in model_options.get("sampler_post_cfg_function", []):
+ args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
+ "sigma": timestep, "model_options": model_options, "input": x}
+ cfg_result = fn(args)
+
+ return cfg_result
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
@@ -272,10 +280,7 @@ class KSamplerX0Inpaint(torch.nn.Module):
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
if denoise_mask is not None:
- out *= denoise_mask
-
- if denoise_mask is not None:
- out += self.latent_image * latent_mask
+ out = out * denoise_mask + self.latent_image * latent_mask
return out
def simple_scheduler(model, steps):
@@ -590,6 +595,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)
+ if latent_image is not None:
+ latent_image = model.process_latent_in(latent_image)
+
+ if hasattr(model, 'extra_conds'):
+ positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
+ negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
+
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
@@ -601,13 +613,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
- if latent_image is not None:
- latent_image = model.process_latent_in(latent_image)
-
- if hasattr(model, 'extra_conds'):
- positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
- negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
-
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
@@ -630,7 +635,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model, steps, sgm=True)
else:
- print("error invalid scheduler", self.scheduler)
+ print("error invalid scheduler", scheduler_name)
return sigmas
def sampler_object(name):
diff --git a/comfy/sd.py b/comfy/sd.py
index cadce719e..c51172a46 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -148,12 +148,14 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
- def __init__(self, sd=None, device=None, config=None):
+ def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
+ self.downscale_ratio = 8
+ self.latent_channels = 4
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -169,6 +171,11 @@ class VAE:
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+
+ if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
+ ddconfig['ch_mult'] = [1, 2, 4]
+ self.downscale_ratio = 4
+
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
@@ -185,8 +192,11 @@ class VAE:
device = model_management.vae_device()
self.device = device
offload_device = model_management.vae_offload_device()
- self.vae_dtype = model_management.vae_dtype()
+ if dtype is None:
+ dtype = model_management.vae_dtype()
+ self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
+ self.output_device = model_management.intermediate_device()
self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@@ -198,9 +208,9 @@ class VAE:
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
output = torch.clamp((
- (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
- utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
- utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
+ (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
+ utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
+ utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
/ 3.0) / 2.0, min=0.0, max=1.0)
return output
@@ -211,9 +221,9 @@ class VAE:
pbar = utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
- samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
- samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
- samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+ samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
+ samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
+ samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0
return samples
@@ -225,15 +235,15 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
- pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
+ pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
- pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
+ pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
- pixel_samples = pixel_samples.cpu().movedim(1,-1)
+ pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@@ -249,10 +259,10 @@ class VAE:
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
- samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
+ samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
- samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
+ samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@@ -429,11 +439,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
+ load_device = model_management.get_torch_device()
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
class WeightsLoader(torch.nn.Module):
pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
+ model_config.set_manual_cast(manual_cast_dtype)
+
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -466,7 +480,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
print("left over keys:", left_over)
if output_model:
- _model_patcher = model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+ _model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
@@ -477,6 +491,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
+ load_device = model_management.get_torch_device()
+ manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
+
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
@@ -497,13 +514,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
+ model_config.set_manual_cast(manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys in unet:", left_over)
- return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+ return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path):
sd = utils.load_torch_file(unet_path)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 0ea9611dd..6722eb83f 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -1,6 +1,6 @@
import os
-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
+from transformers import CLIPTokenizer
from . import ops
import torch
import traceback
@@ -8,6 +8,8 @@ import zipfile
from . import model_management
from pkg_resources import resource_filename
import contextlib
+from . import clip_model
+import json
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
@@ -38,7 +40,7 @@ class ClipTokenWeightEncoder:
out, pooled = self.encode(to_encode)
if pooled is not None:
- first_pooled = pooled[0:1].cpu()
+ first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
@@ -55,8 +57,8 @@ class ClipTokenWeightEncoder:
output.append(z)
if (len(output) == 0):
- return out[-1:].cpu(), first_pooled
- return torch.cat(output, dim=-2).cpu(), first_pooled
+ return out[-1:].to(model_management.intermediate_device()), first_pooled
+ return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
@@ -66,33 +68,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
- freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
- special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
- model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
+ freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=clip_model.CLIPTextModel,
+ special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
- self.num_layers = 12
- if textmodel_path is not None:
- self.transformer = model_class.from_pretrained(textmodel_path)
- else:
- if textmodel_json_config is None:
- textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
- if not os.path.exists(textmodel_json_config):
- textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
- config = config_class.from_json_file(textmodel_json_config)
- self.num_layers = config.num_hidden_layers
- with ops.use_comfy_ops(device, dtype):
- with modeling_utils.no_init_weights():
- self.transformer = model_class(config)
- self.inner_name = inner_name
- if dtype is not None:
- self.transformer.to(dtype)
- inner_model = getattr(self.transformer, self.inner_name)
- if hasattr(inner_model, "embeddings"):
- inner_model.embeddings.to(torch.float32)
- else:
- self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
+ if textmodel_json_config is None:
+ textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
+ if not os.path.exists(textmodel_json_config):
+ textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
+
+ with open(textmodel_json_config) as f:
+ config = json.load(f)
+
+ self.transformer = model_class(config, dtype, device, ops.manual_cast)
+ self.num_layers = self.transformer.num_layers
self.max_length = max_length
if freeze:
@@ -107,7 +97,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
assert layer_idx is not None
- assert abs(layer_idx) <= self.num_layers
+ assert abs(layer_idx) < self.num_layers
self.clip_layer(layer_idx)
self.layer_default = (self.layer, self.layer_idx)
@@ -118,7 +108,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
param.requires_grad = False
def clip_layer(self, layer_idx):
- if abs(layer_idx) >= self.num_layers:
+ if abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
@@ -173,41 +163,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
- if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
- precision_scope = torch.autocast
+ attention_mask = None
+ if self.enable_attention_masks:
+ attention_mask = torch.zeros_like(tokens)
+ max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+ for x in range(attention_mask.shape[0]):
+ for y in range(attention_mask.shape[1]):
+ attention_mask[x, y] = 1
+ if tokens[x, y] == max_token:
+ break
+
+ outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+ self.transformer.set_input_embeddings(backup_embeds)
+
+ if self.layer == "last":
+ z = outputs[0]
else:
- precision_scope = lambda a, dtype: contextlib.nullcontext(a)
+ z = outputs[1]
- with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
- attention_mask = None
- if self.enable_attention_masks:
- attention_mask = torch.zeros_like(tokens)
- max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
- for x in range(attention_mask.shape[0]):
- for y in range(attention_mask.shape[1]):
- attention_mask[x, y] = 1
- if tokens[x, y] == max_token:
- break
+ if outputs[2] is not None:
+ pooled_output = outputs[2].float()
+ else:
+ pooled_output = None
- outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
- self.transformer.set_input_embeddings(backup_embeds)
-
- if self.layer == "last":
- z = outputs.last_hidden_state
- elif self.layer == "pooled":
- z = outputs.pooler_output[:, None, :]
- else:
- z = outputs.hidden_states[self.layer_idx]
- if self.layer_norm_hidden_state:
- z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
-
- if hasattr(outputs, "pooler_output"):
- pooled_output = outputs.pooler_output.float()
- else:
- pooled_output = None
-
- if self.text_projection is not None and pooled_output is not None:
- pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+ if self.text_projection is not None and pooled_output is not None:
+ pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output
def encode(self, tokens):
diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py
index a33ce4210..db22e354d 100644
--- a/comfy/sd2_clip.py
+++ b/comfy/sd2_clip.py
@@ -4,15 +4,15 @@ from . import sd1_clip
import os
class SD2ClipHModel(sd1_clip.SDClipModel):
- def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
+ def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate":
layer="hidden"
- layer_idx=23
+ layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
- super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
+ super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index 0a3d0530f..2d7a9e4e9 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -3,13 +3,13 @@ import torch
import os
class SDXLClipG(sd1_clip.SDClipModel):
- def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
+ def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
- super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
+ super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
def load_sd(self, sd):
@@ -37,7 +37,7 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
- self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
+ self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 455323b96..1d442d4dd 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -217,6 +217,16 @@ class SSD1B(SDXL):
"use_temporal_attention": False,
}
+class Segmind_Vega(SDXL):
+ unet_config = {
+ "model_channels": 320,
+ "use_linear_in_transformer": True,
+ "transformer_depth": [0, 0, 1, 1, 2, 2],
+ "context_dim": 2048,
+ "adm_in_channels": 2816,
+ "use_temporal_attention": False,
+ }
+
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
@@ -242,5 +252,59 @@ class SVD_img2vid(supported_models_base.BASE):
def clip_target(self):
return None
-models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
+class Stable_Zero123(supported_models_base.BASE):
+ unet_config = {
+ "context_dim": 768,
+ "model_channels": 320,
+ "use_linear_in_transformer": False,
+ "adm_in_channels": None,
+ "use_temporal_attention": False,
+ "in_channels": 8,
+ }
+
+ unet_extra_config = {
+ "num_heads": 8,
+ "num_head_channels": -1,
+ }
+
+ clip_vision_prefix = "cond_stage_model.model.visual."
+
+ latent_format = latent_formats.SD15
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
+ return out
+
+ def clip_target(self):
+ return None
+
+class SD_X4Upscaler(SD20):
+ unet_config = {
+ "context_dim": 1024,
+ "model_channels": 256,
+ 'in_channels': 7,
+ "use_linear_in_transformer": True,
+ "adm_in_channels": None,
+ "use_temporal_attention": False,
+ }
+
+ unet_extra_config = {
+ "disable_self_attentions": [True, True, True, False],
+ "num_classes": 1000,
+ "num_heads": 8,
+ "num_head_channels": -1,
+ }
+
+ latent_format = latent_formats.SD_X4
+
+ sampling_settings = {
+ "linear_start": 0.0001,
+ "linear_end": 0.02,
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.SD_X4Upscaler(self, device=device)
+ return out
+
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
models += [SVD_img2vid]
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 3412cfea0..49087d23e 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -22,6 +22,8 @@ class BASE:
sampling_settings = {}
latent_format = latent_formats.LatentFormat
+ manual_cast_dtype = None
+
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
@@ -71,3 +73,5 @@ class BASE:
replace_prefix = {"": "first_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
+ def set_manual_cast(self, manual_cast_dtype):
+ self.manual_cast_dtype = manual_cast_dtype
diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py
index a91da2f91..856f4f2cd 100644
--- a/comfy/taesd/taesd.py
+++ b/comfy/taesd/taesd.py
@@ -7,9 +7,10 @@ import torch
import torch.nn as nn
from .. import utils
+from .. import ops
def conv(n_in, n_out, **kwargs):
- return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+ return ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
@@ -19,7 +20,7 @@ class Block(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
- self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+ self.skip = ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))
diff --git a/comfy/utils.py b/comfy/utils.py
index 0aa655264..8b71ed387 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -378,7 +378,7 @@ def lanczos(samples, width, height):
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
result = torch.stack(images)
- return result
+ return result.to(samples.device, samples.dtype)
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
@@ -407,17 +407,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
@torch.inference_mode()
-def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
- output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
+def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
+ output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
for b in range(samples.shape[0]):
s = samples[b:b+1]
- out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
- out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
+ out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
+ out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
- ps = function(s_in).cpu()
+ ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):
diff --git a/comfy_extras/nodes/nodes_canny.py b/comfy_extras/nodes/nodes_canny.py
index 94d453f2c..730dded5f 100644
--- a/comfy_extras/nodes/nodes_canny.py
+++ b/comfy_extras/nodes/nodes_canny.py
@@ -291,7 +291,7 @@ class Canny:
def detect_edge(self, image, low_threshold, high_threshold):
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
- img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
+ img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
return (img_out,)
NODE_CLASS_MAPPINGS = {
diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py
index 04983f5be..baf72ecc5 100644
--- a/comfy_extras/nodes/nodes_custom_sampler.py
+++ b/comfy_extras/nodes/nodes_custom_sampler.py
@@ -1,9 +1,9 @@
-import comfy.samplers
-import comfy.sample
+from comfy import samplers
+from comfy import sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.cmd import latent_preview
import torch
-import comfy.utils
+from comfy import utils
class BasicScheduler:
@@ -11,8 +11,9 @@ class BasicScheduler:
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
- "scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
+ "scheduler": (samplers.SCHEDULER_NAMES, ),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
@@ -20,8 +21,15 @@ class BasicScheduler:
FUNCTION = "get_sigmas"
- def get_sigmas(self, model, scheduler, steps):
- sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu()
+ def get_sigmas(self, model, scheduler, steps, denoise):
+ total_steps = steps
+ if denoise < 1.0:
+ total_steps = int(steps/denoise)
+
+ inner_model = model.patch_model(patch_weights=False)
+ sigmas = samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
+ model.unpatch_model()
+ sigmas = sigmas[-(steps + 1):]
return (sigmas, )
@@ -87,6 +95,7 @@ class SDTurboScheduler:
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
@@ -94,9 +103,12 @@ class SDTurboScheduler:
FUNCTION = "get_sigmas"
- def get_sigmas(self, model, steps):
- timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
- sigmas = model.model.model_sampling.sigma(timesteps)
+ def get_sigmas(self, model, steps, denoise):
+ start_step = 10 - int(10 * denoise)
+ timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
+ inner_model = model.patch_model(patch_weights=False)
+ sigmas = inner_model.model_sampling.sigma(timesteps)
+ model.unpatch_model()
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
@@ -159,7 +171,7 @@ class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
return {"required":
- {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
+ {"sampler_name": (samplers.SAMPLER_NAMES, ),
}
}
RETURN_TYPES = ("SAMPLER",)
@@ -168,7 +180,7 @@ class KSamplerSelect:
FUNCTION = "get_sampler"
def get_sampler(self, sampler_name):
- sampler = comfy.samplers.sampler_object(sampler_name)
+ sampler = samplers.sampler_object(sampler_name)
return (sampler, )
class SamplerDPMPP_2M_SDE:
@@ -191,7 +203,7 @@ class SamplerDPMPP_2M_SDE:
sampler_name = "dpmpp_2m_sde"
else:
sampler_name = "dpmpp_2m_sde_gpu"
- sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
+ sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
return (sampler, )
@@ -215,7 +227,7 @@ class SamplerDPMPP_SDE:
sampler_name = "dpmpp_sde"
else:
sampler_name = "dpmpp_sde_gpu"
- sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
+ sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, )
class SamplerCustom:
@@ -248,7 +260,7 @@ class SamplerCustom:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
batch_inds = latent["batch_index"] if "batch_index" in latent else None
- noise = comfy.sample.prepare_noise(latent_image, noise_seed, batch_inds)
+ noise = sample.prepare_noise(latent_image, noise_seed, batch_inds)
noise_mask = None
if "noise_mask" in latent:
@@ -257,8 +269,8 @@ class SamplerCustom:
x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
- disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
- samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
+ disable_pbar = not utils.PROGRESS_BAR_ENABLED
+ samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy()
out["samples"] = samples
diff --git a/comfy_extras/nodes/nodes_hypertile.py b/comfy_extras/nodes/nodes_hypertile.py
index 0d7d4c954..e7446b2e5 100644
--- a/comfy_extras/nodes/nodes_hypertile.py
+++ b/comfy_extras/nodes/nodes_hypertile.py
@@ -2,9 +2,10 @@
import math
from einops import rearrange
-import random
+# Use torch rng for consistency across generations
+from torch import randint
-def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
+def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value)
# All big divisors of value (inclusive)
@@ -12,8 +13,10 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
- random.seed(counter)
- idx = random.randint(0, len(ns) - 1)
+ if len(ns) - 1 > 0:
+ idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
+ else:
+ idx = 0
return ns[idx]
@@ -42,7 +45,6 @@ class HyperTile:
latent_tile_size = max(32, tile_size) // 8
self.temp = None
- self.counter = 1
def hypertile_in(q, k, v, extra_options):
if q.shape[-1] in apply_to:
@@ -53,10 +55,8 @@ class HyperTile:
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
- nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
- self.counter += 1
- nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
- self.counter += 1
+ nh = random_divisor(h, latent_tile_size * factor, swap_size)
+ nw = random_divisor(w, latent_tile_size * factor, swap_size)
if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
diff --git a/comfy_extras/nodes/nodes_images.py b/comfy_extras/nodes/nodes_images.py
index d82e0e250..b6202f181 100644
--- a/comfy_extras/nodes/nodes_images.py
+++ b/comfy_extras/nodes/nodes_images.py
@@ -73,7 +73,7 @@ class SaveAnimatedWEBP:
OUTPUT_NODE = True
- CATEGORY = "_for_testing"
+ CATEGORY = "image/animation"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
@@ -135,7 +135,7 @@ class SaveAnimatedPNG:
OUTPUT_NODE = True
- CATEGORY = "_for_testing"
+ CATEGORY = "image/animation"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
diff --git a/comfy_extras/nodes/nodes_latent.py b/comfy_extras/nodes/nodes_latent.py
index cedf39d63..2eefc4c55 100644
--- a/comfy_extras/nodes/nodes_latent.py
+++ b/comfy_extras/nodes/nodes_latent.py
@@ -3,9 +3,7 @@ import torch
def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]:
- latent.movedim(1, -1)
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
- latent.movedim(-1, 1)
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
@@ -102,9 +100,32 @@ class LatentInterpolate:
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
+class LatentBatch:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
+
+ RETURN_TYPES = ("LATENT",)
+ FUNCTION = "batch"
+
+ CATEGORY = "latent/batch"
+
+ def batch(self, samples1, samples2):
+ samples_out = samples1.copy()
+ s1 = samples1["samples"]
+ s2 = samples2["samples"]
+
+ if s1.shape[1:] != s2.shape[1:]:
+ s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
+ s = torch.cat((s1, s2), dim=0)
+ samples_out["samples"] = s
+ samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
+ return (samples_out,)
+
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
+ "LatentBatch": LatentBatch,
}
diff --git a/comfy_extras/nodes/nodes_mask.py b/comfy_extras/nodes/nodes_mask.py
index ae0f9d3b5..1ed455f04 100644
--- a/comfy_extras/nodes/nodes_mask.py
+++ b/comfy_extras/nodes/nodes_mask.py
@@ -7,6 +7,7 @@ from comfy.nodes.common import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
+ source = source.to(destination.device)
if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
@@ -21,7 +22,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if mask is None:
mask = torch.ones_like(source)
else:
- mask = mask.clone()
+ mask = mask.to(destination.device, copy=True)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
diff --git a/comfy_extras/nodes/nodes_model_advanced.py b/comfy_extras/nodes/nodes_model_advanced.py
index 6b95831bc..bbece153a 100644
--- a/comfy_extras/nodes/nodes_model_advanced.py
+++ b/comfy_extras/nodes/nodes_model_advanced.py
@@ -16,41 +16,19 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input
-class ModelSamplingDiscreteDistilled(torch.nn.Module):
+class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50
- def __init__(self):
- super().__init__()
- self.sigma_data = 1.0
- timesteps = 1000
- beta_start = 0.00085
- beta_end = 0.012
+ def __init__(self, model_config=None):
+ super().__init__(model_config)
- betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
- alphas = 1.0 - betas
- alphas_cumprod = torch.cumprod(alphas, dim=0)
+ self.skip_steps = self.num_timesteps // self.original_timesteps
- self.skip_steps = timesteps // self.original_timesteps
-
-
- alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
+ sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps):
- alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
+ sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
- sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
- self.set_sigmas(sigmas)
-
- def set_sigmas(self, sigmas):
- self.register_buffer('sigmas', sigmas)
- self.register_buffer('log_sigmas', sigmas.log())
-
- @property
- def sigma_min(self):
- return self.sigmas[0]
-
- @property
- def sigma_max(self):
- return self.sigmas[-1]
+ self.set_sigmas(sigmas_valid)
def timestep(self, sigma):
log_sigma = sigma.log()
@@ -65,14 +43,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device)
- def percent_to_sigma(self, percent):
- if percent <= 0.0:
- return 999999999.9
- if percent >= 1.0:
- return 0.0
- percent = 1.0 - percent
- return self.sigma(torch.tensor(percent * 999.0)).item()
-
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
@@ -121,7 +91,7 @@ class ModelSamplingDiscrete:
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
- model_sampling = ModelSamplingAdvanced()
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
@@ -153,7 +123,7 @@ class ModelSamplingContinuousEDM:
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass
- model_sampling = ModelSamplingAdvanced()
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
diff --git a/comfy_extras/nodes/nodes_perpneg.py b/comfy_extras/nodes/nodes_perpneg.py
new file mode 100644
index 000000000..875937530
--- /dev/null
+++ b/comfy_extras/nodes/nodes_perpneg.py
@@ -0,0 +1,53 @@
+import torch
+from comfy import sample
+from comfy import samplers
+
+
+class PerpNeg:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"model": ("MODEL", ),
+ "empty_conditioning": ("CONDITIONING", ),
+ "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
+ }}
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "_for_testing"
+
+ def patch(self, model, empty_conditioning, neg_scale):
+ m = model.clone()
+ nocond = sample.convert_cond(empty_conditioning)
+
+ def cfg_function(args):
+ model = args["model"]
+ noise_pred_pos = args["cond_denoised"]
+ noise_pred_neg = args["uncond_denoised"]
+ cond_scale = args["cond_scale"]
+ x = args["input"]
+ sigma = args["sigma"]
+ model_options = args["model_options"]
+ nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
+
+ (noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
+
+ pos = noise_pred_pos - noise_pred_nocond
+ neg = noise_pred_neg - noise_pred_nocond
+ perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg
+ perp_neg = perp * neg_scale
+ cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
+ cfg_result = x - cfg_result
+ return cfg_result
+
+ m.set_model_sampler_cfg_function(cfg_function)
+
+ return (m, )
+
+
+NODE_CLASS_MAPPINGS = {
+ "PerpNeg": PerpNeg,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "PerpNeg": "Perp-Neg",
+}
diff --git a/comfy_extras/nodes/nodes_post_processing.py b/comfy_extras/nodes/nodes_post_processing.py
index 9213eb546..334214ad1 100644
--- a/comfy_extras/nodes/nodes_post_processing.py
+++ b/comfy_extras/nodes/nodes_post_processing.py
@@ -226,7 +226,7 @@ class Sharpen:
batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1
- kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
+ kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
center = kernel_size // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
diff --git a/comfy_extras/nodes/nodes_rebatch.py b/comfy_extras/nodes/nodes_rebatch.py
index 88a4ebe29..3010fbd4b 100644
--- a/comfy_extras/nodes/nodes_rebatch.py
+++ b/comfy_extras/nodes/nodes_rebatch.py
@@ -99,10 +99,40 @@ class LatentRebatch:
return (output_list,)
+class ImageRebatch:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "images": ("IMAGE",),
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+ }}
+ RETURN_TYPES = ("IMAGE",)
+ INPUT_IS_LIST = True
+ OUTPUT_IS_LIST = (True, )
+
+ FUNCTION = "rebatch"
+
+ CATEGORY = "image/batch"
+
+ def rebatch(self, images, batch_size):
+ batch_size = batch_size[0]
+
+ output_list = []
+ all_images = []
+ for img in images:
+ for i in range(img.shape[0]):
+ all_images.append(img[i:i+1])
+
+ for i in range(0, len(all_images), batch_size):
+ output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
+
+ return (output_list,)
+
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
+ "RebatchImages": ImageRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
-}
\ No newline at end of file
+ "RebatchImages": "Rebatch Images",
+}
diff --git a/comfy_extras/nodes/nodes_sag.py b/comfy_extras/nodes/nodes_sag.py
new file mode 100644
index 000000000..66606e328
--- /dev/null
+++ b/comfy_extras/nodes/nodes_sag.py
@@ -0,0 +1,168 @@
+import torch
+from torch import einsum
+import torch.nn.functional as F
+import math
+
+from einops import rearrange, repeat
+import os
+from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
+from comfy import samplers
+
+# from comfy/ldm/modules/attention.py
+# but modified to return attention scores as well as output
+def attention_basic_with_sim(q, k, v, heads, mask=None):
+ b, _, dim_head = q.shape
+ dim_head //= heads
+ scale = dim_head ** -0.5
+
+ h = heads
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, -1, heads, dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * heads, -1, dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # force cast to fp32 to avoid overflowing
+ if _ATTN_PRECISION =="fp32":
+ sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
+ else:
+ sim = einsum('b i d, b j d -> b i j', q, k) * scale
+
+ del q, k
+
+ if mask is not None:
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, heads, -1, dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, -1, heads * dim_head)
+ )
+ return (out, sim)
+
+def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
+ # reshape and GAP the attention map
+ _, hw1, hw2 = attn.shape
+ b, _, lh, lw = x0.shape
+ attn = attn.reshape(b, -1, hw1, hw2)
+ # Global Average Pool
+ mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
+ ratio = math.ceil(math.sqrt(lh * lw / hw1))
+ mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
+
+ # Reshape
+ mask = (
+ mask.reshape(b, *mid_shape)
+ .unsqueeze(1)
+ .type(attn.dtype)
+ )
+ # Upsample
+ mask = F.interpolate(mask, (lh, lw))
+
+ blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
+ blurred = blurred * mask + x0 * (1 - mask)
+ return blurred
+
+def gaussian_blur_2d(img, kernel_size, sigma):
+ ksize_half = (kernel_size - 1) * 0.5
+
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+
+ x_kernel = pdf / pdf.sum()
+ x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
+
+ kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
+ kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
+
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+
+ img = F.pad(img, padding, mode="reflect")
+ img = F.conv2d(img, kernel2d, groups=img.shape[-3])
+ return img
+
+class SelfAttentionGuidance:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "model": ("MODEL",),
+ "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
+ "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
+ }}
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+
+ CATEGORY = "_for_testing"
+
+ def patch(self, model, scale, blur_sigma):
+ m = model.clone()
+
+ attn_scores = None
+
+ # TODO: make this work properly with chunked batches
+ # currently, we can only save the attn from one UNet call
+ def attn_and_record(q, k, v, extra_options):
+ nonlocal attn_scores
+ # if uncond, save the attention scores
+ heads = extra_options["n_heads"]
+ cond_or_uncond = extra_options["cond_or_uncond"]
+ b = q.shape[0] // len(cond_or_uncond)
+ if 1 in cond_or_uncond:
+ uncond_index = cond_or_uncond.index(1)
+ # do the entire attention operation, but save the attention scores to attn_scores
+ (out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
+ # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
+ n_slices = heads * b
+ attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
+ return out
+ else:
+ return optimized_attention(q, k, v, heads=heads)
+
+ def post_cfg_function(args):
+ nonlocal attn_scores
+ uncond_attn = attn_scores
+
+ sag_scale = scale
+ sag_sigma = blur_sigma
+ sag_threshold = 1.0
+ model = args["model"]
+ uncond_pred = args["uncond_denoised"]
+ uncond = args["uncond"]
+ cfg_result = args["denoised"]
+ sigma = args["sigma"]
+ model_options = args["model_options"]
+ x = args["input"]
+
+ # create the adversarially blurred image
+ degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
+ degraded_noised = degraded + x - uncond_pred
+ # call into the UNet
+ (sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
+ return cfg_result + (degraded - sag) * sag_scale
+
+ m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
+
+ # from diffusers:
+ # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
+ m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
+
+ return (m, )
+
+NODE_CLASS_MAPPINGS = {
+ "SelfAttentionGuidance": SelfAttentionGuidance,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "SelfAttentionGuidance": "Self-Attention Guidance",
+}
diff --git a/comfy_extras/nodes/nodes_sdupscale.py b/comfy_extras/nodes/nodes_sdupscale.py
new file mode 100644
index 000000000..4502fcd84
--- /dev/null
+++ b/comfy_extras/nodes/nodes_sdupscale.py
@@ -0,0 +1,46 @@
+import torch
+from comfy import utils
+
+class SD_4XUpscale_Conditioning:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "images": ("IMAGE",),
+ "positive": ("CONDITIONING",),
+ "negative": ("CONDITIONING",),
+ "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+ }}
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+ RETURN_NAMES = ("positive", "negative", "latent")
+
+ FUNCTION = "encode"
+
+ CATEGORY = "conditioning/upscale_diffusion"
+
+ def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
+ width = max(1, round(images.shape[-2] * scale_ratio))
+ height = max(1, round(images.shape[-3] * scale_ratio))
+
+ pixels = utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
+
+ out_cp = []
+ out_cn = []
+
+ for t in positive:
+ n = [t[0], t[1].copy()]
+ n[1]['concat_image'] = pixels
+ n[1]['noise_augmentation'] = noise_augmentation
+ out_cp.append(n)
+
+ for t in negative:
+ n = [t[0], t[1].copy()]
+ n[1]['concat_image'] = pixels
+ n[1]['noise_augmentation'] = noise_augmentation
+ out_cn.append(n)
+
+ latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
+ return (out_cp, out_cn, {"samples":latent})
+
+NODE_CLASS_MAPPINGS = {
+ "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
+}
diff --git a/comfy_extras/nodes/nodes_stable3d.py b/comfy_extras/nodes/nodes_stable3d.py
new file mode 100644
index 000000000..7aa6ec858
--- /dev/null
+++ b/comfy_extras/nodes/nodes_stable3d.py
@@ -0,0 +1,61 @@
+import torch
+import comfy.utils
+
+from comfy.nodes.common import MAX_RESOLUTION
+from comfy import utils
+
+
+def camera_embeddings(elevation, azimuth):
+ elevation = torch.as_tensor([elevation])
+ azimuth = torch.as_tensor([azimuth])
+ embeddings = torch.stack(
+ [
+ torch.deg2rad(
+ (90 - elevation) - (90)
+ ), # Zero123 polar is 90-elevation
+ torch.sin(torch.deg2rad(azimuth)),
+ torch.cos(torch.deg2rad(azimuth)),
+ torch.deg2rad(
+ 90 - torch.full_like(elevation, 0)
+ ),
+ ], dim=-1).unsqueeze(1)
+
+ return embeddings
+
+
+class StableZero123_Conditioning:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "clip_vision": ("CLIP_VISION",),
+ "init_image": ("IMAGE",),
+ "vae": ("VAE",),
+ "width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+ "height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+ "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
+ "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
+ }}
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
+ RETURN_NAMES = ("positive", "negative", "latent")
+
+ FUNCTION = "encode"
+
+ CATEGORY = "conditioning/3d_models"
+
+ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
+ output = clip_vision.encode_image(init_image)
+ pooled = output.image_embeds.unsqueeze(0)
+ pixels = utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
+ encode_pixels = pixels[:,:,:,:3]
+ t = vae.encode(encode_pixels)
+ cam_embeds = camera_embeddings(elevation, azimuth)
+ cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
+
+ positive = [[cond, {"concat_latent_image": t}]]
+ negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8])
+ return (positive, negative, {"samples":latent})
+
+NODE_CLASS_MAPPINGS = {
+ "StableZero123_Conditioning": StableZero123_Conditioning,
+}
diff --git a/requirements.txt b/requirements.txt
index 9bacddc13..df6d90d44 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ torchaudio
torchvision
torchdiffeq>=0.2.3
torchsde>=0.2.6
+torchvision
einops>=0.6.0
open-clip-torch>=2.16.0
transformers>=4.29.1
diff --git a/setup.py b/setup.py
index e69d6146c..ca79d0db8 100644
--- a/setup.py
+++ b/setup.py
@@ -28,18 +28,18 @@ version = '0.0.1'
"""
The package index to the torch built with AMD ROCm.
"""
-amd_torch_index = "https://download.pytorch.org/whl/rocm5.6"
+amd_torch_index = ("https://download.pytorch.org/whl/rocm5.6", "https://download.pytorch.org/whl/nightly/rocm5.7")
"""
The package index to torch built with CUDA.
Observe the CUDA version is in this URL.
"""
-nvidia_torch_index = "https://download.pytorch.org/whl/cu121"
+nvidia_torch_index = ("https://download.pytorch.org/whl/cu121", "https://download.pytorch.org/whl/nightly/cu121")
"""
The package index to torch built against CPU features.
"""
-cpu_torch_index = "https://download.pytorch.org/whl/cpu"
+cpu_torch_index = ("https://download.pytorch.org/whl/cpu", "https://download.pytorch.org/whl/nightly/cpu")
# xformers not required for new torch
@@ -102,11 +102,11 @@ def _is_linux_arm64():
def dependencies() -> List[str]:
_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines()
- # todo: also add all plugin dependencies
_alternative_indices = [amd_torch_index, nvidia_torch_index]
session = PipSession()
- index_urls = ['https://pypi.org/simple']
+ # (stable, nightly) tuple
+ index_urls = [('https://pypi.org/simple', 'https://pypi.org/simple')]
# prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device
if _is_nvidia():
index_urls += [nvidia_torch_index]
@@ -118,6 +118,13 @@ def dependencies() -> List[str]:
if len(index_urls) == 1:
return _dependencies
+ if sys.version_info >= (3, 12):
+ # use the nightlies
+ index_urls = [nightly for (_, nightly) in index_urls]
+ _alternative_indices = [nightly for (_, nightly) in _alternative_indices]
+ else:
+ index_urls = [stable for (stable, _) in index_urls]
+ _alternative_indices = [stable for (stable, _) in _alternative_indices]
try:
# pip 23
finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls, no_index=False)),
@@ -149,7 +156,7 @@ setup(
description="",
author="",
version=version,
- python_requires=">=3.9,<3.12",
+ python_requires=">=3.9,<3.13",
# todo: figure out how to include the web directory to eventually let main live inside the package
# todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins
packages=find_packages(exclude=[] if is_editable else ['custom_nodes']),
diff --git a/tests-ui/afterSetup.js b/tests-ui/afterSetup.js
new file mode 100644
index 000000000..983f3af64
--- /dev/null
+++ b/tests-ui/afterSetup.js
@@ -0,0 +1,9 @@
+const { start } = require("./utils");
+const lg = require("./utils/litegraph");
+
+// Load things once per test file before to ensure its all warmed up for the tests
+beforeAll(async () => {
+ lg.setup(global);
+ await start({ resetEnv: true });
+ lg.teardown(global);
+});
diff --git a/tests-ui/jest.config.js b/tests-ui/jest.config.js
index b5a5d646d..86fff5057 100644
--- a/tests-ui/jest.config.js
+++ b/tests-ui/jest.config.js
@@ -2,8 +2,10 @@
const config = {
testEnvironment: "jsdom",
setupFiles: ["./globalSetup.js"],
+ setupFilesAfterEnv: ["./afterSetup.js"],
clearMocks: true,
resetModules: true,
+ testTimeout: 10000
};
module.exports = config;
diff --git a/tests-ui/tests/extensions.test.js b/tests-ui/tests/extensions.test.js
index b82e55c32..159e5113a 100644
--- a/tests-ui/tests/extensions.test.js
+++ b/tests-ui/tests/extensions.test.js
@@ -52,7 +52,7 @@ describe("extensions", () => {
const nodeNames = Object.keys(defs);
const nodeCount = nodeNames.length;
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
- for (let i = 0; i < nodeCount; i++) {
+ for (let i = 0; i < 10; i++) {
// It should be send the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
@@ -133,7 +133,7 @@ describe("extensions", () => {
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
- });
+ }, 15000);
it("allows custom nodeDefs and widgets to be registered", async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => {
diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js
index ce54c1154..e6ebedd91 100644
--- a/tests-ui/tests/groupNode.test.js
+++ b/tests-ui/tests/groupNode.test.js
@@ -1,7 +1,7 @@
// @ts-check
///
-const { start, createDefaultWorkflow } = require("../utils");
+const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils");
const lg = require("../utils/litegraph");
describe("group node", () => {
@@ -273,7 +273,7 @@ describe("group node", () => {
let reroutes = [];
let prevNode = nodes.ckpt;
- for(let i = 0; i < 5; i++) {
+ for (let i = 0; i < 5; i++) {
const reroute = ez.Reroute();
prevNode.outputs[0].connectTo(reroute.inputs[0]);
prevNode = reroute;
@@ -283,7 +283,7 @@ describe("group node", () => {
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
expect((await graph.toPrompt()).output).toEqual(getOutput());
-
+
group.menu["Convert to nodes"].call();
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
@@ -383,6 +383,43 @@ describe("group node", () => {
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
);
});
+ test("groups can connect to each other via internal reroutes", async () => {
+ const { ez, graph, app } = await start();
+
+ const latent = ez.EmptyLatentImage();
+ const vae = ez.VAELoader();
+ const latentReroute = ez.Reroute();
+ const vaeReroute = ez.Reroute();
+
+ latent.outputs[0].connectTo(latentReroute.inputs[0]);
+ vae.outputs[0].connectTo(vaeReroute.inputs[0]);
+
+ const group1 = await convertToGroup(app, graph, "test", [latentReroute, vaeReroute]);
+ group1.menu.Clone.call();
+ expect(app.graph._nodes).toHaveLength(4);
+ const group2 = graph.find(app.graph._nodes[3]);
+ expect(group2.node.type).toEqual("workflow/test");
+ expect(group2.id).not.toEqual(group1.id);
+
+ group1.outputs.VAE.connectTo(group2.inputs.VAE);
+ group1.outputs.LATENT.connectTo(group2.inputs.LATENT);
+
+ const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE);
+ const preview = ez.PreviewImage(decode.outputs[0]);
+
+ const output = {
+ [latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" },
+ [vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
+ [decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" },
+ [preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" },
+ };
+ expect((await graph.toPrompt()).output).toEqual(output);
+
+ // Ensure missing connections dont cause errors
+ group2.inputs.VAE.disconnect();
+ delete output[decode.id].inputs.vae;
+ expect((await graph.toPrompt()).output).toEqual(output);
+ });
test("displays generated image on group node", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
@@ -642,6 +679,55 @@ describe("group node", () => {
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
});
});
+ test("correctly handles widget inputs", async () => {
+ const { ez, graph, app } = await start();
+ const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0];
+
+ const image = ez.LoadImage();
+ const scale1 = ez.ImageScaleBy(image.outputs[0]);
+ const scale2 = ez.ImageScaleBy(image.outputs[0]);
+ const preview1 = ez.PreviewImage(scale1.outputs[0]);
+ const preview2 = ez.PreviewImage(scale2.outputs[0]);
+ scale1.widgets.upscale_method.value = upscaleMethods[1];
+ scale1.widgets.upscale_method.convertToInput();
+
+ const group = await convertToGroup(app, graph, "test", [scale1, scale2]);
+ expect(group.inputs.length).toBe(3);
+ expect(group.inputs[0].input.type).toBe("IMAGE");
+ expect(group.inputs[1].input.type).toBe("IMAGE");
+ expect(group.inputs[2].input.type).toBe("COMBO");
+
+ // Ensure links are maintained
+ expect(group.inputs[0].connection?.originNode?.id).toBe(image.id);
+ expect(group.inputs[1].connection?.originNode?.id).toBe(image.id);
+ expect(group.inputs[2].connection).toBeFalsy();
+
+ // Ensure primitive gets correct type
+ const primitive = ez.PrimitiveNode();
+ primitive.outputs[0].connectTo(group.inputs[2]);
+ expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods);
+ expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied
+ primitive.widgets.value.value = upscaleMethods[1];
+
+ await checkBeforeAndAfterReload(graph, async (r) => {
+ const scale1id = r ? `${group.id}:0` : scale1.id;
+ const scale2id = r ? `${group.id}:1` : scale2.id;
+ // Ensure widget value is applied to prompt
+ expect((await graph.toPrompt()).output).toStrictEqual({
+ [image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
+ [scale1id]: {
+ inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
+ class_type: "ImageScaleBy",
+ },
+ [scale2id]: {
+ inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
+ class_type: "ImageScaleBy",
+ },
+ [preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" },
+ [preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" },
+ });
+ });
+ });
test("adds widgets in node execution order", async () => {
const { ez, graph, app } = await start();
const scale = ez.LatentUpscale();
@@ -815,4 +901,105 @@ describe("group node", () => {
expect(p2.widgets.control_after_generate.value).toBe("randomize");
expect(p2.widgets.control_filter_list.value).toBe("/.+/");
});
+ test("internal reroutes work with converted inputs and merge options", async () => {
+ const { ez, graph, app } = await start();
+ const vae = ez.VAELoader();
+ const latent = ez.EmptyLatentImage();
+ const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE);
+ const scale = ez.ImageScale(decode.outputs.IMAGE);
+ ez.PreviewImage(scale.outputs.IMAGE);
+
+ const r1 = ez.Reroute();
+ const r2 = ez.Reroute();
+
+ latent.widgets.width.value = 64;
+ latent.widgets.height.value = 128;
+
+ latent.widgets.width.convertToInput();
+ latent.widgets.height.convertToInput();
+ latent.widgets.batch_size.convertToInput();
+
+ scale.widgets.width.convertToInput();
+ scale.widgets.height.convertToInput();
+
+ r1.inputs[0].input.label = "hbw";
+ r1.outputs[0].connectTo(latent.inputs.height);
+ r1.outputs[0].connectTo(latent.inputs.batch_size);
+ r1.outputs[0].connectTo(scale.inputs.width);
+
+ r2.inputs[0].input.label = "wh";
+ r2.outputs[0].connectTo(latent.inputs.width);
+ r2.outputs[0].connectTo(scale.inputs.height);
+
+ const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]);
+
+ expect(group.inputs[0].input.type).toBe("VAE");
+ expect(group.inputs[1].input.type).toBe("INT");
+ expect(group.inputs[2].input.type).toBe("INT");
+
+ const p1 = ez.PrimitiveNode();
+ const p2 = ez.PrimitiveNode();
+ p1.outputs[0].connectTo(group.inputs[1]);
+ p2.outputs[0].connectTo(group.inputs[2]);
+
+ expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min
+ expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max
+ expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
+
+ expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
+ expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
+ expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
+
+ expect(p1.widgets.value.value).toBe(128);
+ expect(p2.widgets.value.value).toBe(64);
+
+ p1.widgets.value.value = 16;
+ p2.widgets.value.value = 32;
+
+ await checkBeforeAndAfterReload(graph, async (r) => {
+ const id = (v) => (r ? `${group.id}:` : "") + v;
+ expect((await graph.toPrompt()).output).toStrictEqual({
+ 1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
+ [id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" },
+ [id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" },
+ [id(4)]: {
+ inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] },
+ class_type: "ImageScale",
+ },
+ 5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" },
+ });
+ });
+ });
+ test("converted inputs with linked widgets map values correctly on creation", async () => {
+ const { ez, graph, app } = await start();
+ const k1 = ez.KSampler();
+ const k2 = ez.KSampler();
+ k1.widgets.seed.convertToInput();
+ k2.widgets.seed.convertToInput();
+
+ const rr = ez.Reroute();
+ rr.outputs[0].connectTo(k1.inputs.seed);
+ rr.outputs[0].connectTo(k2.inputs.seed);
+
+ const group = await convertToGroup(app, graph, "test", [k1, k2, rr]);
+ expect(group.widgets.steps.value).toBe(20);
+ expect(group.widgets.cfg.value).toBe(8);
+ expect(group.widgets.scheduler.value).toBe("normal");
+ expect(group.widgets["KSampler steps"].value).toBe(20);
+ expect(group.widgets["KSampler cfg"].value).toBe(8);
+ expect(group.widgets["KSampler scheduler"].value).toBe("normal");
+ });
+ test("allow multiple of the same node type to be added", async () => {
+ const { ez, graph, app } = await start();
+ const nodes = [...Array(10)].map(() => ez.ImageScaleBy());
+ const group = await convertToGroup(app, graph, "test", nodes);
+ expect(group.inputs.length).toBe(10);
+ expect(group.outputs.length).toBe(10);
+ expect(group.widgets.length).toBe(20);
+ expect(group.widgets.map((w) => w.widget.name)).toStrictEqual(
+ [...Array(10)]
+ .map((_, i) => `${i > 0 ? "ImageScaleBy " : ""}${i > 1 ? i + " " : ""}`)
+ .flatMap((p) => [`${p}upscale_method`, `${p}scale_by`])
+ );
+ });
});
diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js
index 8e191adf0..67e3fa341 100644
--- a/tests-ui/tests/widgetInputs.test.js
+++ b/tests-ui/tests/widgetInputs.test.js
@@ -1,7 +1,13 @@
// @ts-check
///
-const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils");
+const {
+ start,
+ makeNodeDef,
+ checkBeforeAndAfterReload,
+ assertNotNullOrUndefined,
+ createDefaultWorkflow,
+} = require("../utils");
const lg = require("../utils/litegraph");
/**
@@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWi
if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo");
- if(widgetType === "combo") {
+ if (widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string");
}
@@ -308,8 +314,8 @@ describe("widget inputs", () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", {}, [["A", "B"]]),
- ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }),
- ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }),
+ ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
+ ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }),
},
});
@@ -330,7 +336,7 @@ describe("widget inputs", () => {
const n1 = ez.TestNode1();
n1.widgets.example.convertToInput();
- const p = ez.PrimitiveNode()
+ const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value;
@@ -380,7 +386,7 @@ describe("widget inputs", () => {
// Check random
control.value = "randomize";
filter.value = "/D/";
- for(let i = 0; i < 100; i++) {
+ for (let i = 0; i < 100; i++) {
control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy();
}
@@ -392,4 +398,160 @@ describe("widget inputs", () => {
control["afterQueued"]();
expect(value.value).toBe("B");
});
+
+ describe("reroutes", () => {
+ async function checkOutput(graph, values) {
+ expect((await graph.toPrompt()).output).toStrictEqual({
+ 1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" },
+ 2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
+ 3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
+ 4: {
+ inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 },
+ class_type: "EmptyLatentImage",
+ },
+ 5: {
+ inputs: {
+ seed: 0,
+ steps: 20,
+ cfg: 8,
+ sampler_name: "euler",
+ scheduler: values?.scheduler ?? "normal",
+ denoise: 1,
+ model: ["1", 0],
+ positive: ["2", 0],
+ negative: ["3", 0],
+ latent_image: ["4", 0],
+ },
+ class_type: "KSampler",
+ },
+ 6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" },
+ 7: {
+ inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] },
+ class_type: "SaveImage",
+ },
+ });
+ }
+
+ async function waitForWidget(node) {
+ // widgets are created slightly after the graph is ready
+ // hard to find an exact hook to get these so just wait for them to be ready
+ for (let i = 0; i < 10; i++) {
+ await new Promise((r) => setTimeout(r, 10));
+ if (node.widgets?.value) {
+ return;
+ }
+ }
+ }
+
+ it("can connect primitive via a reroute path to a widget input", async () => {
+ const { ez, graph } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ nodes.empty.widgets.width.convertToInput();
+ nodes.sampler.widgets.scheduler.convertToInput();
+ nodes.save.widgets.filename_prefix.convertToInput();
+
+ let widthReroute = ez.Reroute();
+ let schedulerReroute = ez.Reroute();
+ let fileReroute = ez.Reroute();
+
+ let widthNext = widthReroute;
+ let schedulerNext = schedulerReroute;
+ let fileNext = fileReroute;
+
+ for (let i = 0; i < 5; i++) {
+ let next = ez.Reroute();
+ widthNext.outputs[0].connectTo(next.inputs[0]);
+ widthNext = next;
+
+ next = ez.Reroute();
+ schedulerNext.outputs[0].connectTo(next.inputs[0]);
+ schedulerNext = next;
+
+ next = ez.Reroute();
+ fileNext.outputs[0].connectTo(next.inputs[0]);
+ fileNext = next;
+ }
+
+ widthNext.outputs[0].connectTo(nodes.empty.inputs.width);
+ schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler);
+ fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix);
+
+ let widthPrimitive = ez.PrimitiveNode();
+ let schedulerPrimitive = ez.PrimitiveNode();
+ let filePrimitive = ez.PrimitiveNode();
+
+ widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]);
+ schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]);
+ filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]);
+ expect(widthPrimitive.widgets.value.value).toBe(512);
+ widthPrimitive.widgets.value.value = 1024;
+ expect(schedulerPrimitive.widgets.value.value).toBe("normal");
+ schedulerPrimitive.widgets.value.value = "simple";
+ expect(filePrimitive.widgets.value.value).toBe("ComfyUI");
+ filePrimitive.widgets.value.value = "ComfyTest";
+
+ await checkBeforeAndAfterReload(graph, async () => {
+ widthPrimitive = graph.find(widthPrimitive);
+ schedulerPrimitive = graph.find(schedulerPrimitive);
+ filePrimitive = graph.find(filePrimitive);
+ await waitForWidget(filePrimitive);
+ expect(widthPrimitive.widgets.length).toBe(2);
+ expect(schedulerPrimitive.widgets.length).toBe(3);
+ expect(filePrimitive.widgets.length).toBe(1);
+
+ await checkOutput(graph, {
+ width: 1024,
+ scheduler: "simple",
+ filename_prefix: "ComfyTest",
+ });
+ });
+ });
+ it("can connect primitive via a reroute path to multiple widget inputs", async () => {
+ const { ez, graph } = await start();
+ const nodes = createDefaultWorkflow(ez, graph);
+
+ nodes.empty.widgets.width.convertToInput();
+ nodes.empty.widgets.height.convertToInput();
+ nodes.empty.widgets.batch_size.convertToInput();
+
+ let reroute = ez.Reroute();
+ let prevReroute = reroute;
+ for (let i = 0; i < 5; i++) {
+ const next = ez.Reroute();
+ prevReroute.outputs[0].connectTo(next.inputs[0]);
+ prevReroute = next;
+ }
+
+ const r1 = ez.Reroute(prevReroute.outputs[0]);
+ const r2 = ez.Reroute(prevReroute.outputs[0]);
+ const r3 = ez.Reroute(r2.outputs[0]);
+ const r4 = ez.Reroute(r2.outputs[0]);
+
+ r1.outputs[0].connectTo(nodes.empty.inputs.width);
+ r3.outputs[0].connectTo(nodes.empty.inputs.height);
+ r4.outputs[0].connectTo(nodes.empty.inputs.batch_size);
+
+ let primitive = ez.PrimitiveNode();
+ primitive.outputs[0].connectTo(reroute.inputs[0]);
+ expect(primitive.widgets.value.value).toBe(1);
+ primitive.widgets.value.value = 64;
+
+ await checkBeforeAndAfterReload(graph, async (r) => {
+ primitive = graph.find(primitive);
+ await waitForWidget(primitive);
+
+ // Ensure widget configs are merged
+ expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min
+ expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max
+ expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
+
+ await checkOutput(graph, {
+ width: 64,
+ height: 64,
+ batch_size: 64,
+ });
+ });
+ });
+ });
});
diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js
index 898b82db0..8a55246ee 100644
--- a/tests-ui/utils/ezgraph.js
+++ b/tests-ui/utils/ezgraph.js
@@ -78,6 +78,14 @@ export class EzInput extends EzSlot {
this.input = input;
}
+ get connection() {
+ const link = this.node.node.inputs?.[this.index]?.link;
+ if (link == null) {
+ return null;
+ }
+ return new EzConnection(this.node.app, this.node.app.graph.links[link]);
+ }
+
disconnect() {
this.node.node.disconnectInput(this.index);
}
@@ -117,7 +125,7 @@ export class EzOutput extends EzSlot {
const inp = input.input;
const inName = inp.name || inp.label || inp.type;
throw new Error(
- `Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${
+ `Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${
this.output.name ?? this.output.type
}#${this.index}] failed.`
);
@@ -179,6 +187,7 @@ export class EzWidget {
set value(v) {
this.widget.value = v;
+ this.widget.callback?.call?.(this.widget, v)
}
get isConvertedToInput() {
@@ -319,7 +328,7 @@ export class EzGraph {
}
stringify() {
- return JSON.stringify(this.app.graph.serialize(), undefined, "\t");
+ return JSON.stringify(this.app.graph.serialize(), undefined);
}
/**
diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js
index 3a018f566..6a08e8594 100644
--- a/tests-ui/utils/index.js
+++ b/tests-ui/utils/index.js
@@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) {
return { ckpt, pos, neg, empty, sampler, decode, save };
}
+
+export async function getNodeDefs() {
+ const { api } = require("../../web/scripts/api");
+ return api.getNodeDefs();
+}
+
+export async function getNodeDef(nodeId) {
+ return (await getNodeDefs())[nodeId];
+}
\ No newline at end of file
diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js
index 6766f356d..4cf1f7621 100644
--- a/web/extensions/core/groupNode.js
+++ b/web/extensions/core/groupNode.js
@@ -174,6 +174,11 @@ export class GroupNodeConfig {
node.index = i;
this.processNode(node, seenInputs, seenOutputs);
}
+
+ for (const p of this.#convertedToProcess) {
+ p();
+ }
+ this.#convertedToProcess = null;
await app.registerNodeDef("workflow/" + this.name, this.nodeDef);
}
@@ -192,7 +197,10 @@ export class GroupNodeConfig {
if (!this.linksFrom[sourceNodeId]) {
this.linksFrom[sourceNodeId] = {};
}
- this.linksFrom[sourceNodeId][sourceNodeSlot] = l;
+ if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) {
+ this.linksFrom[sourceNodeId][sourceNodeSlot] = [];
+ }
+ this.linksFrom[sourceNodeId][sourceNodeSlot].push(l);
if (!this.linksTo[targetNodeId]) {
this.linksTo[targetNodeId] = {};
@@ -230,11 +238,11 @@ export class GroupNodeConfig {
// Skip as its not linked
if (!linksFrom) return;
- let type = linksFrom["0"][5];
+ let type = linksFrom["0"][0][5];
if (type === "COMBO") {
// Use the array items
const source = node.outputs[0].widget.name;
- const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type;
+ const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type;
const fromType = globalDefs[fromTypeName];
const input = fromType.input.required[source] ?? fromType.input.optional[source];
type = input[0];
@@ -258,10 +266,33 @@ export class GroupNodeConfig {
return null;
}
+ let config = {};
let rerouteType = "*";
if (linksFrom) {
- const [, , id, slot] = linksFrom["0"];
- rerouteType = this.nodeData.nodes[id].inputs[slot].type;
+ for (const [, , id, slot] of linksFrom["0"]) {
+ const node = this.nodeData.nodes[id];
+ const input = node.inputs[slot];
+ if (rerouteType === "*") {
+ rerouteType = input.type;
+ }
+ if (input.widget) {
+ const targetDef = globalDefs[node.type];
+ const targetWidget =
+ targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
+
+ const widget = [targetWidget[0], config];
+ const res = mergeIfValid(
+ {
+ widget,
+ },
+ targetWidget,
+ false,
+ null,
+ widget
+ );
+ config = res?.customConfig ?? config;
+ }
+ }
} else if (linksTo) {
const [id, slot] = linksTo["0"];
rerouteType = this.nodeData.nodes[id].outputs[slot].type;
@@ -282,10 +313,11 @@ export class GroupNodeConfig {
}
}
+ config.forceInput = true;
return {
input: {
required: {
- [rerouteType]: [rerouteType, {}],
+ [rerouteType]: [rerouteType, config],
},
},
output: [rerouteType],
@@ -299,16 +331,17 @@ export class GroupNodeConfig {
getInputConfig(node, inputName, seenInputs, config, extra) {
let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
+ let key = name;
let prefix = "";
// Special handling for primitive to include the title if it is set rather than just "value"
if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) {
prefix = `${node.title ?? node.type} `;
- name = `${prefix}${inputName}`;
+ key = name = `${prefix}${inputName}`;
if (name in seenInputs) {
name = `${prefix}${seenInputs[name]} ${inputName}`;
}
}
- seenInputs[name] = (seenInputs[name] ?? 1) + 1;
+ seenInputs[key] = (seenInputs[key] ?? 1) + 1;
if (inputName === "seed" || inputName === "noise_seed") {
if (!extra) extra = {};
@@ -420,10 +453,18 @@ export class GroupNodeConfig {
defaultInput: true,
});
this.nodeDef.input.required[name] = config;
+ this.newToOldWidgetMap[name] = { node, inputName };
+
+ if (!this.oldToNewWidgetMap[node.index]) {
+ this.oldToNewWidgetMap[node.index] = {};
+ }
+ this.oldToNewWidgetMap[node.index][inputName] = name;
+
inputMap[slots.length + i] = this.inputCount++;
}
}
+ #convertedToProcess = [];
processNodeInputs(node, seenInputs, inputs) {
const inputMapping = [];
@@ -434,7 +475,11 @@ export class GroupNodeConfig {
const linksTo = this.linksTo[node.index] ?? {};
const inputMap = (this.oldToNewInputMap[node.index] = {});
this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
- this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs);
+
+ // Converted inputs have to be processed after all other nodes as they'll be at the end of the list
+ this.#convertedToProcess.push(() =>
+ this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)
+ );
return inputMapping;
}
@@ -597,11 +642,19 @@ export class GroupNodeHandler {
const output = this.groupData.newToOldOutputMap[link.origin_slot];
let innerNode = this.innerNodes[output.node.index];
let l;
- while (innerNode.type === "Reroute") {
+ while (innerNode?.type === "Reroute") {
l = innerNode.getInputLink(0);
innerNode = innerNode.getInputNode(0);
}
+ if (!innerNode) {
+ return null;
+ }
+
+ if (l && GroupNodeHandler.isGroupNode(innerNode)) {
+ return innerNode.updateLink(l);
+ }
+
link.origin_id = innerNode.id;
link.origin_slot = l?.origin_slot ?? output.slot;
return link;
@@ -665,6 +718,8 @@ export class GroupNodeHandler {
top = newNode.pos[1];
}
+ if (!newNode.widgets) continue;
+
const map = this.groupData.oldToNewWidgetMap[innerNode.index];
if (map) {
const widgets = Object.keys(map);
@@ -721,7 +776,7 @@ export class GroupNodeHandler {
}
};
- const reconnectOutputs = () => {
+ const reconnectOutputs = (selectedIds) => {
for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) {
const output = node.outputs[groupOutputId];
if (!output.links) continue;
@@ -861,7 +916,7 @@ export class GroupNodeHandler {
if (innerNode.type === "PrimitiveNode") {
innerNode.primitiveValue = newValue;
const primitiveLinked = this.groupData.primitiveToWidget[old.node.index];
- for (const linked of primitiveLinked) {
+ for (const linked of primitiveLinked ?? []) {
const node = this.innerNodes[linked.nodeId];
const widget = node.widgets.find((w) => w.name === linked.inputName);
@@ -870,6 +925,18 @@ export class GroupNodeHandler {
}
}
continue;
+ } else if (innerNode.type === "Reroute") {
+ const rerouteLinks = this.groupData.linksFrom[old.node.index];
+ for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
+ const node = this.innerNodes[targetNodeId];
+ const input = node.inputs[targetSlot];
+ if (input.widget) {
+ const widget = node.widgets?.find((w) => w.name === input.widget.name);
+ if (widget) {
+ widget.value = newValue;
+ }
+ }
+ }
}
const widget = innerNode.widgets?.find((w) => w.name === old.inputName);
@@ -897,33 +964,58 @@ export class GroupNodeHandler {
this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value;
}
}
+ return true;
}
+ populateReroute(node, nodeId, map) {
+ if (node.type !== "Reroute") return;
+
+ const link = this.groupData.linksFrom[nodeId]?.[0]?.[0];
+ if (!link) return;
+ const [, , targetNodeId, targetNodeSlot] = link;
+ const targetNode = this.groupData.nodeData.nodes[targetNodeId];
+ const inputs = targetNode.inputs;
+ const targetWidget = inputs?.[targetNodeSlot].widget;
+ if (!targetWidget) return;
+
+ const offset = inputs.length - (targetNode.widgets_values?.length ?? 0);
+ const v = targetNode.widgets_values?.[targetNodeSlot - offset];
+ if (v == null) return;
+
+ const widgetName = Object.values(map)[0];
+ const widget = this.node.widgets.find(w => w.name === widgetName);
+ if(widget) {
+ widget.value = v;
+ }
+ }
+
+
populateWidgets() {
+ if (!this.node.widgets) return;
+
for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) {
const node = this.groupData.nodeData.nodes[nodeId];
-
- if (!node.widgets_values?.length) continue;
-
- const map = this.groupData.oldToNewWidgetMap[nodeId];
+ const map = this.groupData.oldToNewWidgetMap[nodeId] ?? {};
const widgets = Object.keys(map);
+ if (!node.widgets_values?.length) {
+ // special handling for populating values into reroutes
+ // this allows primitives connect to them to pick up the correct value
+ this.populateReroute(node, nodeId, map);
+ continue;
+ }
+
let linkedShift = 0;
for (let i = 0; i < widgets.length; i++) {
const oldName = widgets[i];
const newName = map[oldName];
const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName);
const mainWidget = this.node.widgets[widgetIndex];
- if (!newName) {
- // New name will be null if its a converted widget
- this.populatePrimitive(node, nodeId, oldName, i, linkedShift);
-
+ if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift) || widgetIndex === -1) {
// Find the inner widget and shift by the number of linked widgets as they will have been removed too
const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName);
- linkedShift += innerWidget.linkedWidgets?.length ?? 0;
- continue;
+ linkedShift += innerWidget?.linkedWidgets?.length ?? 0;
}
-
if (widgetIndex === -1) {
continue;
}
diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js
index 8ace79562..bb2f16d42 100644
--- a/web/extensions/core/maskeditor.js
+++ b/web/extensions/core/maskeditor.js
@@ -33,6 +33,18 @@ function loadedImageToBlob(image) {
return blob;
}
+function loadImage(imagePath) {
+ return new Promise((resolve, reject) => {
+ const image = new Image();
+
+ image.onload = function() {
+ resolve(image);
+ };
+
+ image.src = imagePath;
+ });
+}
+
async function uploadMask(filepath, formData) {
await api.fetchApi('/upload/mask', {
method: 'POST',
@@ -50,25 +62,25 @@ async function uploadMask(filepath, formData) {
ClipspaceDialog.invalidatePreview();
}
-function prepareRGB(image, backupCanvas, backupCtx) {
+function prepare_mask(image, maskCanvas, maskCtx) {
// paste mask data into alpha channel
- backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height);
- const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
+ maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height);
+ const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height);
- // refine mask image
- for (let i = 0; i < backupData.data.length; i += 4) {
- if(backupData.data[i+3] == 255)
- backupData.data[i+3] = 0;
+ // invert mask
+ for (let i = 0; i < maskData.data.length; i += 4) {
+ if(maskData.data[i+3] == 255)
+ maskData.data[i+3] = 0;
else
- backupData.data[i+3] = 255;
+ maskData.data[i+3] = 255;
- backupData.data[i] = 0;
- backupData.data[i+1] = 0;
- backupData.data[i+2] = 0;
+ maskData.data[i] = 0;
+ maskData.data[i+1] = 0;
+ maskData.data[i+2] = 0;
}
- backupCtx.globalCompositeOperation = 'source-over';
- backupCtx.putImageData(backupData, 0, 0);
+ maskCtx.globalCompositeOperation = 'source-over';
+ maskCtx.putImageData(maskData, 0, 0);
}
class MaskEditorDialog extends ComfyDialog {
@@ -155,10 +167,6 @@ class MaskEditorDialog extends ComfyDialog {
// If it is specified as relative, using it only as a hidden placeholder for padding is recommended
// to prevent anomalies where it exceeds a certain size and goes outside of the window.
- var placeholder = document.createElement("div");
- placeholder.style.position = "relative";
- placeholder.style.height = "50px";
-
var bottom_panel = document.createElement("div");
bottom_panel.style.position = "absolute";
bottom_panel.style.bottom = "0px";
@@ -180,18 +188,16 @@ class MaskEditorDialog extends ComfyDialog {
this.brush = brush;
this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas);
- this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
this.element.appendChild(bottom_panel);
document.body.appendChild(brush);
- var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
+ this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
self.brush_size = event.target.value;
self.updateBrushPreview(self, null, null);
});
var clearButton = this.createLeftButton("Clear",
() => {
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
- self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
});
var cancelButton = this.createRightButton("Cancel", () => {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
@@ -207,40 +213,42 @@ class MaskEditorDialog extends ComfyDialog {
this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas);
- this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
this.element.appendChild(bottom_panel);
bottom_panel.appendChild(clearButton);
bottom_panel.appendChild(this.saveButton);
bottom_panel.appendChild(cancelButton);
- bottom_panel.appendChild(brush_size_slider);
+ bottom_panel.appendChild(this.brush_size_slider);
+
+ imgCanvas.style.position = "absolute";
+ maskCanvas.style.position = "absolute";
- imgCanvas.style.position = "relative";
imgCanvas.style.top = "200";
imgCanvas.style.left = "0";
- maskCanvas.style.position = "absolute";
+ maskCanvas.style.top = imgCanvas.style.top;
+ maskCanvas.style.left = imgCanvas.style.left;
}
- show() {
+ async show() {
+ this.zoom_ratio = 1.0;
+ this.pan_x = 0;
+ this.pan_y = 0;
+
if(!this.is_layout_created) {
// layout
const imgCanvas = document.createElement('canvas');
const maskCanvas = document.createElement('canvas');
- const backupCanvas = document.createElement('canvas');
imgCanvas.id = "imageCanvas";
maskCanvas.id = "maskCanvas";
- backupCanvas.id = "backupCanvas";
this.setlayout(imgCanvas, maskCanvas);
// prepare content
this.imgCanvas = imgCanvas;
this.maskCanvas = maskCanvas;
- this.backupCanvas = backupCanvas;
- this.maskCtx = maskCanvas.getContext('2d');
- this.backupCtx = backupCanvas.getContext('2d');
+ this.maskCtx = maskCanvas.getContext('2d', {willReadFrequently: true });
this.setEventHandler(maskCanvas);
@@ -252,6 +260,8 @@ class MaskEditorDialog extends ComfyDialog {
mutations.forEach(function(mutation) {
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
+ document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
+ self.brush.style.display = "none";
ComfyApp.onClipspaceEditorClosed();
}
@@ -264,7 +274,8 @@ class MaskEditorDialog extends ComfyDialog {
observer.observe(this.element, config);
}
- this.setImages(this.imgCanvas, this.backupCanvas);
+ // The keydown event needs to be reconfigured when closing the dialog as it gets removed.
+ document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
if(ComfyApp.clipspace_return_node) {
this.saveButton.innerText = "Save to node";
@@ -275,97 +286,157 @@ class MaskEditorDialog extends ComfyDialog {
this.saveButton.disabled = false;
this.element.style.display = "block";
+ this.element.style.width = "85%";
+ this.element.style.margin = "0 7.5%";
+ this.element.style.height = "100vh";
+ this.element.style.top = "50%";
+ this.element.style.left = "42%";
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
+
+ await this.setImages(this.imgCanvas);
+
+ this.is_visible = true;
}
isOpened() {
return this.element.style.display == "block";
}
- setImages(imgCanvas, backupCanvas) {
- const imgCtx = imgCanvas.getContext('2d');
- const backupCtx = backupCanvas.getContext('2d');
+ invalidateCanvas(orig_image, mask_image) {
+ this.imgCanvas.width = orig_image.width;
+ this.imgCanvas.height = orig_image.height;
+
+ this.maskCanvas.width = orig_image.width;
+ this.maskCanvas.height = orig_image.height;
+
+ let imgCtx = this.imgCanvas.getContext('2d', {willReadFrequently: true });
+ let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true });
+
+ imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height);
+ prepare_mask(mask_image, this.maskCanvas, maskCtx);
+ }
+
+ async setImages(imgCanvas) {
+ let self = this;
+
+ const imgCtx = imgCanvas.getContext('2d', {willReadFrequently: true });
const maskCtx = this.maskCtx;
const maskCanvas = this.maskCanvas;
- backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
// image load
- const orig_image = new Image();
- window.addEventListener("resize", () => {
- // repositioning
- imgCanvas.width = window.innerWidth - 250;
- imgCanvas.height = window.innerHeight - 200;
-
- // redraw image
- let drawWidth = orig_image.width;
- let drawHeight = orig_image.height;
- if (orig_image.width > imgCanvas.width) {
- drawWidth = imgCanvas.width;
- drawHeight = (drawWidth / orig_image.width) * orig_image.height;
- }
-
- if (drawHeight > imgCanvas.height) {
- drawHeight = imgCanvas.height;
- drawWidth = (drawHeight / orig_image.height) * orig_image.width;
- }
-
- imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
-
- // update mask
- maskCanvas.width = drawWidth;
- maskCanvas.height = drawHeight;
- maskCanvas.style.top = imgCanvas.offsetTop + "px";
- maskCanvas.style.left = imgCanvas.offsetLeft + "px";
- backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
- maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
- });
-
const filepath = ComfyApp.clipspace.images;
- const touched_image = new Image();
-
- touched_image.onload = function() {
- backupCanvas.width = touched_image.width;
- backupCanvas.height = touched_image.height;
-
- prepareRGB(touched_image, backupCanvas, backupCtx);
- };
-
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
alpha_url.searchParams.delete('channel');
alpha_url.searchParams.delete('preview');
alpha_url.searchParams.set('channel', 'a');
- touched_image.src = alpha_url;
+ let mask_image = await loadImage(alpha_url);
// original image load
- orig_image.onload = function() {
- window.dispatchEvent(new Event('resize'));
- };
-
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
rgb_url.searchParams.delete('channel');
rgb_url.searchParams.set('channel', 'rgb');
- orig_image.src = rgb_url;
- this.image = orig_image;
+ this.image = new Image();
+ this.image.onload = function() {
+ maskCanvas.width = self.image.width;
+ maskCanvas.height = self.image.height;
+
+ self.invalidateCanvas(self.image, mask_image);
+ self.initializeCanvasPanZoom();
+ };
+ this.image.src = rgb_url;
}
- setEventHandler(maskCanvas) {
- maskCanvas.addEventListener("contextmenu", (event) => {
- event.preventDefault();
- });
+ initializeCanvasPanZoom() {
+ // set initialize
+ let drawWidth = this.image.width;
+ let drawHeight = this.image.height;
+ let width = this.element.clientWidth;
+ let height = this.element.clientHeight;
+
+ if (this.image.width > width) {
+ drawWidth = width;
+ drawHeight = (drawWidth / this.image.width) * this.image.height;
+ }
+
+ if (drawHeight > height) {
+ drawHeight = height;
+ drawWidth = (drawHeight / this.image.height) * this.image.width;
+ }
+
+ this.zoom_ratio = drawWidth/this.image.width;
+
+ const canvasX = (width - drawWidth) / 2;
+ const canvasY = (height - drawHeight) / 2;
+ this.pan_x = canvasX;
+ this.pan_y = canvasY;
+
+ this.invalidatePanZoom();
+ }
+
+
+ invalidatePanZoom() {
+ let raw_width = this.image.width * this.zoom_ratio;
+ let raw_height = this.image.height * this.zoom_ratio;
+
+ if(this.pan_x + raw_width < 10) {
+ this.pan_x = 10 - raw_width;
+ }
+
+ if(this.pan_y + raw_height < 10) {
+ this.pan_y = 10 - raw_height;
+ }
+
+ let width = `${raw_width}px`;
+ let height = `${raw_height}px`;
+
+ let left = `${this.pan_x}px`;
+ let top = `${this.pan_y}px`;
+
+ this.maskCanvas.style.width = width;
+ this.maskCanvas.style.height = height;
+ this.maskCanvas.style.left = left;
+ this.maskCanvas.style.top = top;
+
+ this.imgCanvas.style.width = width;
+ this.imgCanvas.style.height = height;
+ this.imgCanvas.style.left = left;
+ this.imgCanvas.style.top = top;
+ }
+
+
+ setEventHandler(maskCanvas) {
const self = this;
- maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
- maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
- document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
- maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
- maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
- maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
- maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
- document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
+
+ if(!this.handler_registered) {
+ maskCanvas.addEventListener("contextmenu", (event) => {
+ event.preventDefault();
+ });
+
+ this.element.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
+ this.element.addEventListener('pointermove', (event) => this.pointMoveEvent(self,event));
+ this.element.addEventListener('touchmove', (event) => this.pointMoveEvent(self,event));
+
+ this.element.addEventListener('dragstart', (event) => {
+ if(event.ctrlKey) {
+ event.preventDefault();
+ }
+ });
+
+ maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
+ maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
+ maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
+ maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
+ maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
+
+ document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
+
+ this.handler_registered = true;
+ }
}
brush_size = 10;
@@ -378,8 +449,10 @@ class MaskEditorDialog extends ComfyDialog {
const self = MaskEditorDialog.instance;
if (event.key === ']') {
self.brush_size = Math.min(self.brush_size+2, 100);
+ self.brush_slider_input.value = self.brush_size;
} else if (event.key === '[') {
self.brush_size = Math.max(self.brush_size-2, 1);
+ self.brush_slider_input.value = self.brush_size;
} else if(event.key === 'Enter') {
self.save();
}
@@ -389,6 +462,10 @@ class MaskEditorDialog extends ComfyDialog {
static handlePointerUp(event) {
event.preventDefault();
+
+ this.mousedown_x = null;
+ this.mousedown_y = null;
+
MaskEditorDialog.instance.drawing_mode = false;
}
@@ -398,24 +475,70 @@ class MaskEditorDialog extends ComfyDialog {
var centerX = self.cursorX;
var centerY = self.cursorY;
- brush.style.width = self.brush_size * 2 + "px";
- brush.style.height = self.brush_size * 2 + "px";
- brush.style.left = (centerX - self.brush_size) + "px";
- brush.style.top = (centerY - self.brush_size) + "px";
+ brush.style.width = self.brush_size * 2 * this.zoom_ratio + "px";
+ brush.style.height = self.brush_size * 2 * this.zoom_ratio + "px";
+ brush.style.left = (centerX - self.brush_size * this.zoom_ratio) + "px";
+ brush.style.top = (centerY - self.brush_size * this.zoom_ratio) + "px";
}
handleWheelEvent(self, event) {
- if(event.deltaY < 0)
- self.brush_size = Math.min(self.brush_size+2, 100);
- else
- self.brush_size = Math.max(self.brush_size-2, 1);
+ event.preventDefault();
- self.brush_slider_input.value = self.brush_size;
+ if(event.ctrlKey) {
+ // zoom canvas
+ if(event.deltaY < 0) {
+ this.zoom_ratio = Math.min(10.0, this.zoom_ratio+0.2);
+ }
+ else {
+ this.zoom_ratio = Math.max(0.2, this.zoom_ratio-0.2);
+ }
+
+ this.invalidatePanZoom();
+ }
+ else {
+ // adjust brush size
+ if(event.deltaY < 0)
+ this.brush_size = Math.min(this.brush_size+2, 100);
+ else
+ this.brush_size = Math.max(this.brush_size-2, 1);
+
+ this.brush_slider_input.value = this.brush_size;
+
+ this.updateBrushPreview(this);
+ }
+ }
+
+ pointMoveEvent(self, event) {
+ this.cursorX = event.pageX;
+ this.cursorY = event.pageY;
self.updateBrushPreview(self);
+
+ if(event.ctrlKey) {
+ event.preventDefault();
+ self.pan_move(self, event);
+ }
+ }
+
+ pan_move(self, event) {
+ if(event.buttons == 1) {
+ if(this.mousedown_x) {
+ let deltaX = this.mousedown_x - event.clientX;
+ let deltaY = this.mousedown_y - event.clientY;
+
+ self.pan_x = this.mousedown_pan_x - deltaX;
+ self.pan_y = this.mousedown_pan_y - deltaY;
+
+ self.invalidatePanZoom();
+ }
+ }
}
draw_move(self, event) {
+ if(event.ctrlKey) {
+ return;
+ }
+
event.preventDefault();
this.cursorX = event.pageX;
@@ -439,6 +562,9 @@ class MaskEditorDialog extends ComfyDialog {
y = event.targetTouches[0].clientY - maskRect.top;
}
+ x /= self.zoom_ratio;
+ y /= self.zoom_ratio;
+
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure;
@@ -489,8 +615,8 @@ class MaskEditorDialog extends ComfyDialog {
}
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
const maskRect = self.maskCanvas.getBoundingClientRect();
- const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
- const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
+ const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
+ const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
@@ -540,6 +666,17 @@ class MaskEditorDialog extends ComfyDialog {
}
handlePointerDown(self, event) {
+ if(event.ctrlKey) {
+ if (event.buttons == 1) {
+ this.mousedown_x = event.clientX;
+ this.mousedown_y = event.clientY;
+
+ this.mousedown_pan_x = this.pan_x;
+ this.mousedown_pan_y = this.pan_y;
+ }
+ return;
+ }
+
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure;
@@ -551,8 +688,8 @@ class MaskEditorDialog extends ComfyDialog {
event.preventDefault();
const maskRect = self.maskCanvas.getBoundingClientRect();
- const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
- const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
+ const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
+ const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
self.maskCtx.beginPath();
if (event.button == 0) {
@@ -570,15 +707,18 @@ class MaskEditorDialog extends ComfyDialog {
}
async save() {
- const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
+ const backupCanvas = document.createElement('canvas');
+ const backupCtx = backupCanvas.getContext('2d', {willReadFrequently:true});
+ backupCanvas.width = this.image.width;
+ backupCanvas.height = this.image.height;
- backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
+ backupCtx.clearRect(0,0, backupCanvas.width, backupCanvas.height);
backupCtx.drawImage(this.maskCanvas,
0, 0, this.maskCanvas.width, this.maskCanvas.height,
- 0, 0, this.backupCanvas.width, this.backupCanvas.height);
+ 0, 0, backupCanvas.width, backupCanvas.height);
// paste mask data into alpha channel
- const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
+ const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
// refine mask image
for (let i = 0; i < backupData.data.length; i += 4) {
@@ -615,7 +755,7 @@ class MaskEditorDialog extends ComfyDialog {
ComfyApp.clipspace.widgets[index].value = item;
}
- const dataURL = this.backupCanvas.toDataURL();
+ const dataURL = backupCanvas.toDataURL();
const blob = dataURLToBlob(dataURL);
let original_url = new URL(this.image.src);
diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js
index 499a171da..4feff91e5 100644
--- a/web/extensions/core/rerouteNode.js
+++ b/web/extensions/core/rerouteNode.js
@@ -1,10 +1,11 @@
import { app } from "../../scripts/app.js";
+import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs.js";
// Node that allows you to redirect connections for cleaner graphs
app.registerExtension({
name: "Comfy.RerouteNode",
- registerCustomNodes() {
+ registerCustomNodes(app) {
class RerouteNode {
constructor() {
if (!this.properties) {
@@ -16,6 +17,12 @@ app.registerExtension({
this.addInput("", "*");
this.addOutput(this.properties.showOutputText ? "*" : "", "*");
+ this.onAfterGraphConfigured = function () {
+ requestAnimationFrame(() => {
+ this.onConnectionsChange(LiteGraph.INPUT, null, true, null);
+ });
+ };
+
this.onConnectionsChange = function (type, index, connected, link_info) {
this.applyOrientation();
@@ -47,6 +54,7 @@ app.registerExtension({
const linkId = currentNode.inputs[0].link;
if (linkId !== null) {
const link = app.graph.links[linkId];
+ if (!link) return;
const node = app.graph.getNodeById(link.origin_id);
const type = node.constructor.type;
if (type === "Reroute") {
@@ -54,8 +62,7 @@ app.registerExtension({
// We've found a circle
currentNode.disconnectInput(link.target_slot);
currentNode = null;
- }
- else {
+ } else {
// Move the previous node
currentNode = node;
}
@@ -94,8 +101,11 @@ app.registerExtension({
updateNodes.push(node);
} else {
// We've found an output
- const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
- if (inputType && nodeOutType !== inputType) {
+ const nodeOutType =
+ node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type
+ ? node.inputs[link.target_slot].type
+ : null;
+ if (inputType && inputType !== "*" && nodeOutType !== inputType) {
// The output doesnt match our input so disconnect it
node.disconnectInput(link.target_slot);
} else {
@@ -111,6 +121,9 @@ app.registerExtension({
const displayType = inputType || outputType || "*";
const color = LGraphCanvas.link_type_colors[displayType];
+ let widgetConfig;
+ let targetWidget;
+ let widgetType;
// Update the types of each node
for (const node of updateNodes) {
// If we dont have an input type we are always wildcard but we'll show the output type
@@ -125,10 +138,38 @@ app.registerExtension({
const link = app.graph.links[l];
if (link) {
link.color = color;
+
+ if (app.configuringGraph) continue;
+ const targetNode = app.graph.getNodeById(link.target_id);
+ const targetInput = targetNode.inputs?.[link.target_slot];
+ if (targetInput?.widget) {
+ const config = getWidgetConfig(targetInput);
+ if (!widgetConfig) {
+ widgetConfig = config[1] ?? {};
+ widgetType = config[0];
+ }
+ if (!targetWidget) {
+ targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name);
+ }
+
+ const merged = mergeIfValid(targetInput, [config[0], widgetConfig]);
+ if (merged.customConfig) {
+ widgetConfig = merged.customConfig;
+ }
+ }
}
}
}
+ for (const node of updateNodes) {
+ if (widgetConfig && outputType) {
+ node.inputs[0].widget = { name: "value" };
+ setWidgetConfig(node.inputs[0], [widgetType ?? displayType, widgetConfig], targetWidget);
+ } else {
+ setWidgetConfig(node.inputs[0], null);
+ }
+ }
+
if (inputNode) {
const link = app.graph.links[inputNode.inputs[0].link];
if (link) {
@@ -173,8 +214,8 @@ app.registerExtension({
},
{
// naming is inverted with respect to LiteGraphNode.horizontal
- // LiteGraphNode.horizontal == true means that
- // each slot in the inputs and outputs are layed out horizontally,
+ // LiteGraphNode.horizontal == true means that
+ // each slot in the inputs and outputs are layed out horizontally,
// which is the opposite of the visual orientation of the inputs and outputs as a node
content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"),
callback: () => {
@@ -187,7 +228,7 @@ app.registerExtension({
applyOrientation() {
this.horizontal = this.properties.horizontal;
if (this.horizontal) {
- // we correct the input position, because LiteGraphNode.horizontal
+ // we correct the input position, because LiteGraphNode.horizontal
// doesn't account for title presence
// which reroute nodes don't have
this.inputs[0].pos = [this.size[0] / 2, 0];
diff --git a/web/extensions/core/saveImageExtraOutput.js b/web/extensions/core/saveImageExtraOutput.js
index 99e2213bf..a0506b43b 100644
--- a/web/extensions/core/saveImageExtraOutput.js
+++ b/web/extensions/core/saveImageExtraOutput.js
@@ -1,5 +1,5 @@
import { app } from "../../scripts/app.js";
-
+import { applyTextReplacements } from "../../scripts/utils.js";
// Use widget values and dates in output filenames
app.registerExtension({
@@ -7,84 +7,19 @@ app.registerExtension({
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "SaveImage") {
const onNodeCreated = nodeType.prototype.onNodeCreated;
-
- // Simple date formatter
- const parts = {
- d: (d) => d.getDate(),
- M: (d) => d.getMonth() + 1,
- h: (d) => d.getHours(),
- m: (d) => d.getMinutes(),
- s: (d) => d.getSeconds(),
- };
- const format =
- Object.keys(parts)
- .map((k) => k + k + "?")
- .join("|") + "|yyy?y?";
-
- function formatDate(text, date) {
- return text.replace(new RegExp(format, "g"), function (text) {
- if (text === "yy") return (date.getFullYear() + "").substring(2);
- if (text === "yyyy") return date.getFullYear();
- if (text[0] in parts) {
- const p = parts[text[0]](date);
- return (p + "").padStart(text.length, "0");
- }
- return text;
- });
- }
-
- // When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R
+ // When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R
nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
const widget = this.widgets.find((w) => w.name === "filename_prefix");
widget.serializeValue = () => {
- return widget.value.replace(/%([^%]+)%/g, function (match, text) {
- const split = text.split(".");
- if (split.length !== 2) {
- // Special handling for dates
- if (split[0].startsWith("date:")) {
- return formatDate(split[0].substring(5), new Date());
- }
-
- if (text !== "width" && text !== "height") {
- // Dont warn on standard replacements
- console.warn("Invalid replacement pattern", text);
- }
- return match;
- }
-
- // Find node with matching S&R property name
- let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]);
- // If we cant, see if there is a node with that title
- if (!nodes.length) {
- nodes = app.graph._nodes.filter((n) => n.title === split[0]);
- }
- if (!nodes.length) {
- console.warn("Unable to find node", split[0]);
- return match;
- }
-
- if (nodes.length > 1) {
- console.warn("Multiple nodes matched", split[0], "using first match");
- }
-
- const node = nodes[0];
-
- const widget = node.widgets?.find((w) => w.name === split[1]);
- if (!widget) {
- console.warn("Unable to find widget", split[1], "on node", split[0], node);
- return match;
- }
-
- return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
- });
+ return applyTextReplacements(app, widget.value);
};
return r;
};
} else {
- // When any other node is created add a property to alias the node
+ // When any other node is created add a property to alias the node
const onNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js
index c6613b0f0..3cb137520 100644
--- a/web/extensions/core/undoRedo.js
+++ b/web/extensions/core/undoRedo.js
@@ -71,24 +71,21 @@ function graphEqual(a, b, root = true) {
}
const undoRedo = async (e) => {
+ const updateState = async (source, target) => {
+ const prevState = source.pop();
+ if (prevState) {
+ target.push(activeState);
+ isOurLoad = true;
+ await app.loadGraphData(prevState, false);
+ activeState = prevState;
+ }
+ }
if (e.ctrlKey || e.metaKey) {
if (e.key === "y") {
- const prevState = redo.pop();
- if (prevState) {
- undo.push(activeState);
- isOurLoad = true;
- await app.loadGraphData(prevState);
- activeState = prevState;
- }
+ updateState(redo, undo);
return true;
} else if (e.key === "z") {
- const prevState = undo.pop();
- if (prevState) {
- redo.push(activeState);
- isOurLoad = true;
- await app.loadGraphData(prevState);
- activeState = prevState;
- }
+ updateState(undo, redo);
return true;
}
}
diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js
index b6fa411f7..3f1c1f8c1 100644
--- a/web/extensions/core/widgetInputs.js
+++ b/web/extensions/core/widgetInputs.js
@@ -1,10 +1,16 @@
import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js";
+import { applyTextReplacements } from "../../scripts/utils.js";
const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
const CONFIG = Symbol();
const GET_CONFIG = Symbol();
+const TARGET = Symbol(); // Used for reroutes to specify the real target widget
+
+export function getWidgetConfig(slot) {
+ return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]();
+}
function getConfig(widgetName) {
const { nodeData } = this.constructor;
@@ -100,7 +106,6 @@ function getWidgetType(config) {
return { type };
}
-
function isValidCombo(combo, obj) {
// New input isnt a combo
if (!(obj instanceof Array)) {
@@ -121,6 +126,31 @@ function isValidCombo(combo, obj) {
return true;
}
+export function setWidgetConfig(slot, config, target) {
+ if (!slot.widget) return;
+ if (config) {
+ slot.widget[GET_CONFIG] = () => config;
+ slot.widget[TARGET] = target;
+ } else {
+ delete slot.widget;
+ }
+
+ if (slot.link) {
+ const link = app.graph.links[slot.link];
+ if (link) {
+ const originNode = app.graph.getNodeById(link.origin_id);
+ if (originNode.type === "PrimitiveNode") {
+ if (config) {
+ originNode.recreateWidget();
+ } else if(!app.configuringGraph) {
+ originNode.disconnectOutput(0);
+ originNode.onLastDisconnect();
+ }
+ }
+ }
+ }
+}
+
export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
if (!config1) {
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
@@ -150,7 +180,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
- if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
+ if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline") {
let v1 = config1[1][k];
let v2 = config2[1]?.[k];
@@ -405,11 +435,16 @@ app.registerExtension({
};
},
registerCustomNodes() {
+ const replacePropertyName = "Run widget replace on values";
class PrimitiveNode {
constructor() {
this.addOutput("connect to widget input", "*");
this.serialize_widgets = true;
this.isVirtualNode = true;
+
+ if (!this.properties || !(replacePropertyName in this.properties)) {
+ this.addProperty(replacePropertyName, false, "boolean");
+ }
}
applyToGraph(extraLinks = []) {
@@ -430,18 +465,29 @@ app.registerExtension({
}
let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks];
+ let v = this.widgets?.[0].value;
+ if(v && this.properties[replacePropertyName]) {
+ v = applyTextReplacements(app, v);
+ }
+
// For each output link copy our value over the original widget value
for (const linkInfo of links) {
const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot];
- const widgetName = input.widget.name;
- if (widgetName) {
- const widget = node.widgets.find((w) => w.name === widgetName);
- if (widget) {
- widget.value = this.widgets[0].value;
- if (widget.callback) {
- widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {});
- }
+ let widget;
+ if (input.widget[TARGET]) {
+ widget = input.widget[TARGET];
+ } else {
+ const widgetName = input.widget.name;
+ if (widgetName) {
+ widget = node.widgets.find((w) => w.name === widgetName);
+ }
+ }
+
+ if (widget) {
+ widget.value = v;
+ if (widget.callback) {
+ widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {});
}
}
}
@@ -494,14 +540,13 @@ app.registerExtension({
this.#mergeWidgetConfig();
if (!links?.length) {
- this.#onLastDisconnect();
+ this.onLastDisconnect();
}
}
}
onConnectOutput(slot, type, input, target_node, target_slot) {
// Fires before the link is made allowing us to reject it if it isn't valid
-
// No widget, we cant connect
if (!input.widget) {
if (!(input.type in ComfyWidgets)) return false;
@@ -519,6 +564,10 @@ app.registerExtension({
#onFirstConnection(recreating) {
// First connection can fire before the graph is ready on initial load so random things can be missing
+ if (!this.outputs[0].links) {
+ this.onLastDisconnect();
+ return;
+ }
const linkId = this.outputs[0].links[0];
const link = this.graph.links[linkId];
if (!link) return;
@@ -546,10 +595,10 @@ app.registerExtension({
this.outputs[0].name = type;
this.outputs[0].widget = widget;
- this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating);
+ this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating, widget[TARGET]);
}
- #createWidget(inputData, node, widgetName, recreating) {
+ #createWidget(inputData, node, widgetName, recreating, targetWidget) {
let type = inputData[0];
if (type instanceof Array) {
@@ -563,7 +612,9 @@ app.registerExtension({
widget = this.addWidget(type, "value", null, () => {}, {});
}
- if (node?.widgets && widget) {
+ if (targetWidget) {
+ widget.value = targetWidget.value;
+ } else if (node?.widgets && widget) {
const theirWidget = node.widgets.find((w) => w.name === widgetName);
if (theirWidget) {
widget.value = theirWidget.value;
@@ -577,11 +628,19 @@ app.registerExtension({
}
addValueControlWidgets(this, widget, control_value, undefined, inputData);
let filter = this.widgets_values?.[2];
- if(filter && this.widgets.length === 3) {
+ if (filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
}
}
+ // Restore any saved control values
+ const controlValues = this.controlValues;
+ if(this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) {
+ for(let i = 0; i < controlValues.length; i++) {
+ this.widgets[i + 1].value = controlValues[i];
+ }
+ }
+
// When our value changes, update other widgets to reflect our changes
// e.g. so LoadImage shows correct image
const callback = widget.callback;
@@ -610,12 +669,14 @@ app.registerExtension({
}
}
- #recreateWidget() {
- const values = this.widgets.map((w) => w.value);
+ recreateWidget() {
+ const values = this.widgets?.map((w) => w.value);
this.#removeWidgets();
this.#onFirstConnection(true);
- for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
- return this.widgets[0];
+ if (values?.length) {
+ for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
+ }
+ return this.widgets?.[0];
}
#mergeWidgetConfig() {
@@ -631,7 +692,7 @@ app.registerExtension({
if (links?.length < 2 && hasConfig) {
// Copy the widget options from the source
if (links.length) {
- this.#recreateWidget();
+ this.recreateWidget();
}
return;
@@ -657,7 +718,7 @@ app.registerExtension({
// Only allow connections where the configs match
const output = this.outputs[0];
const config2 = input.widget[GET_CONFIG]();
- return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget);
+ return !!mergeIfValid.call(this, output, config2, forceUpdate, this.recreateWidget);
}
#removeWidgets() {
@@ -668,11 +729,20 @@ app.registerExtension({
w.onRemove();
}
}
+
+ // Temporarily store the current values in case the node is being recreated
+ // e.g. by group node conversion
+ this.controlValues = [];
+ this.lastType = this.widgets[0]?.type;
+ for(let i = 1; i < this.widgets.length; i++) {
+ this.controlValues.push(this.widgets[i].value);
+ }
+ setTimeout(() => { delete this.lastType; delete this.controlValues }, 15);
this.widgets.length = 0;
}
}
- #onLastDisconnect() {
+ onLastDisconnect() {
// We cant remove + re-add the output here as if you drag a link over the same link
// it removes, then re-adds, causing it to break
this.outputs[0].type = "*";
diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js
index f571edb30..434c4a83b 100644
--- a/web/lib/litegraph.core.js
+++ b/web/lib/litegraph.core.js
@@ -48,7 +48,7 @@
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA",
- MAX_NUMBER_OF_NODES: 1000, //avoid infinite loops
+ MAX_NUMBER_OF_NODES: 10000, //avoid infinite loops
DEFAULT_POSITION: [100, 100], //default node position
VALID_SHAPES: ["default", "box", "round", "card"], //,"circle"
@@ -3788,16 +3788,42 @@
/**
* returns the bounding of the object, used for rendering purposes
- * bounding is: [topleft_cornerx, topleft_cornery, width, height]
* @method getBounding
- * @return {Float32Array[4]} the total size
+ * @param out {Float32Array[4]?} [optional] a place to store the output, to free garbage
+ * @param compute_outer {boolean?} [optional] set to true to include the shadow and connection points in the bounding calculation
+ * @return {Float32Array[4]} the bounding box in format of [topleft_cornerx, topleft_cornery, width, height]
*/
- LGraphNode.prototype.getBounding = function(out) {
+ LGraphNode.prototype.getBounding = function(out, compute_outer) {
out = out || new Float32Array(4);
- out[0] = this.pos[0] - 4;
- out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
- out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
- out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
+ const nodePos = this.pos;
+ const isCollapsed = this.flags.collapsed;
+ const nodeSize = this.size;
+
+ let left_offset = 0;
+ // 1 offset due to how nodes are rendered
+ let right_offset = 1 ;
+ let top_offset = 0;
+ let bottom_offset = 0;
+
+ if (compute_outer) {
+ // 4 offset for collapsed node connection points
+ left_offset = 4;
+ // 6 offset for right shadow and collapsed node connection points
+ right_offset = 6 + left_offset;
+ // 4 offset for collapsed nodes top connection points
+ top_offset = 4;
+ // 5 offset for bottom shadow and collapsed node connection points
+ bottom_offset = 5 + top_offset;
+ }
+
+ out[0] = nodePos[0] - left_offset;
+ out[1] = nodePos[1] - LiteGraph.NODE_TITLE_HEIGHT - top_offset;
+ out[2] = isCollapsed ?
+ (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) + right_offset :
+ nodeSize[0] + right_offset;
+ out[3] = isCollapsed ?
+ LiteGraph.NODE_TITLE_HEIGHT + bottom_offset :
+ nodeSize[1] + LiteGraph.NODE_TITLE_HEIGHT + bottom_offset;
if (this.onBounding) {
this.onBounding(out);
@@ -7674,7 +7700,7 @@ LGraphNode.prototype.executeAction = function(action)
continue;
}
- if (!overlapBounding(this.visible_area, n.getBounding(temp))) {
+ if (!overlapBounding(this.visible_area, n.getBounding(temp, true))) {
continue;
} //out of the visible area
@@ -11336,6 +11362,7 @@ LGraphNode.prototype.executeAction = function(action)
name_element.innerText = title;
var value_element = dialog.querySelector(".value");
value_element.value = value;
+ value_element.select();
var input = value_element;
input.addEventListener("keydown", function(e) {
diff --git a/web/scripts/app.js b/web/scripts/app.js
index 5faf41fb3..7353f5a3b 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -1559,9 +1559,12 @@ export class ComfyApp {
/**
* Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object
+ * @param { boolean } clean If the graph state, e.g. images, should be cleared
*/
- async loadGraphData(graphData) {
- this.clean();
+ async loadGraphData(graphData, clean = true) {
+ if (clean !== false) {
+ this.clean();
+ }
let reset_invalid_values = false;
if (!graphData) {
@@ -1771,15 +1774,26 @@ export class ComfyApp {
if (parent?.updateLink) {
link = parent.updateLink(link);
}
- inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
+ if (link) {
+ inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
+ }
}
}
}
- output[String(node.id)] = {
+ let node_data = {
inputs,
class_type: node.comfyClass,
};
+
+ if (this.ui.settings.getSettingValue("Comfy.DevMode")) {
+ // Ignored by the backend.
+ node_data["_meta"] = {
+ title: node.title,
+ }
+ }
+
+ output[String(node.id)] = node_data;
}
}
@@ -2006,12 +2020,8 @@ export class ComfyApp {
async refreshComboInNodes() {
const defs = await api.getNodeDefs();
- for(const nodeId in LiteGraph.registered_node_types) {
- const node = LiteGraph.registered_node_types[nodeId];
- const nodeDef = defs[nodeId];
- if(!nodeDef) continue;
-
- node.nodeData = nodeDef;
+ for (const nodeId in defs) {
+ this.registerNodeDef(nodeId, defs[nodeId]);
}
for(let nodeNum in this.graph._nodes) {
diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js
index e919428a0..eb0742d38 100644
--- a/web/scripts/domWidget.js
+++ b/web/scripts/domWidget.js
@@ -177,6 +177,7 @@ LGraphCanvas.prototype.computeVisibleNodes = function () {
for (const w of node.widgets) {
if (w.element) {
w.element.hidden = hidden;
+ w.element.style.display = hidden ? "none" : undefined;
if (hidden) {
w.options.onHide?.(w);
}
diff --git a/web/scripts/utils.js b/web/scripts/utils.js
new file mode 100644
index 000000000..401aca9e4
--- /dev/null
+++ b/web/scripts/utils.js
@@ -0,0 +1,67 @@
+// Simple date formatter
+const parts = {
+ d: (d) => d.getDate(),
+ M: (d) => d.getMonth() + 1,
+ h: (d) => d.getHours(),
+ m: (d) => d.getMinutes(),
+ s: (d) => d.getSeconds(),
+};
+const format =
+ Object.keys(parts)
+ .map((k) => k + k + "?")
+ .join("|") + "|yyy?y?";
+
+function formatDate(text, date) {
+ return text.replace(new RegExp(format, "g"), function (text) {
+ if (text === "yy") return (date.getFullYear() + "").substring(2);
+ if (text === "yyyy") return date.getFullYear();
+ if (text[0] in parts) {
+ const p = parts[text[0]](date);
+ return (p + "").padStart(text.length, "0");
+ }
+ return text;
+ });
+}
+
+export function applyTextReplacements(app, value) {
+ return value.replace(/%([^%]+)%/g, function (match, text) {
+ const split = text.split(".");
+ if (split.length !== 2) {
+ // Special handling for dates
+ if (split[0].startsWith("date:")) {
+ return formatDate(split[0].substring(5), new Date());
+ }
+
+ if (text !== "width" && text !== "height") {
+ // Dont warn on standard replacements
+ console.warn("Invalid replacement pattern", text);
+ }
+ return match;
+ }
+
+ // Find node with matching S&R property name
+ let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]);
+ // If we cant, see if there is a node with that title
+ if (!nodes.length) {
+ nodes = app.graph._nodes.filter((n) => n.title === split[0]);
+ }
+ if (!nodes.length) {
+ console.warn("Unable to find node", split[0]);
+ return match;
+ }
+
+ if (nodes.length > 1) {
+ console.warn("Multiple nodes matched", split[0], "using first match");
+ }
+
+ const node = nodes[0];
+
+ const widget = node.widgets?.find((w) => w.name === split[1]);
+ if (!widget) {
+ console.warn("Unable to find widget", split[1], "on node", split[0], node);
+ return match;
+ }
+
+ return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
+ });
+}