Merge upstream, fix 3.12 compatibility, fix nightlies issue, fix broken node

doctorpangloss 2024-01-03 16:00:36 -08:00
commit 369aeb598f
75 changed files with 2798 additions and 1075 deletions


@ -1,3 +0,0 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
pause


@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention
pause


@ -22,5 +22,5 @@ jobs:
run: | run: |
npm ci npm ci
npm run test:generate npm run test:generate
npm test npm test -- --verbose
working-directory: ./tests-ui working-directory: ./tests-ui


@ -2,6 +2,24 @@ name: "Windows Release Nightly pytorch"
on: on:
workflow_dispatch: workflow_dispatch:
inputs:
cu:
description: 'cuda version'
required: true
type: string
default: "121"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "1"
# push: # push:
# branches: # branches:
# - master # - master
@ -20,21 +38,21 @@ jobs:
persist-credentials: false persist-credentials: false
- uses: actions/setup-python@v4 - uses: actions/setup-python@v4
with: with:
python-version: '3.11.6' python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash - shell: bash
run: | run: |
cd .. cd ..
cp -r ComfyUI ComfyUI_copy cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded unzip python_embeded.zip -d python_embeded
cd python_embeded cd python_embeded
echo 'import site' >> ./python311._pth echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py ./python.exe get-pip.py
python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/* ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python311._pth sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
cd .. cd ..
git clone https://github.com/comfyanonymous/taesd git clone https://github.com/comfyanonymous/taesd
@ -49,9 +67,10 @@ jobs:
mkdir update mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./ cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
echo "..\python_embeded\python.exe .\update.py ..\ComfyUI\\
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > ./update/update_comfyui_and_python_dependencies.bat
cd .. cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch


@ -1,9 +0,0 @@
{
"path-intellisense.mappings": {
"../": "${workspaceFolder}/web/extensions/core"
},
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8"
},
"python.formatting.provider": "none"
}

__init__.py (new file, 0 lines)


@ -53,7 +53,7 @@ class ControlNet(nn.Module):
transformer_depth_middle=None, transformer_depth_middle=None,
transformer_depth_output=None, transformer_depth_output=None,
device=None, device=None,
operations=ops, operations=ops.disable_weight_init,
**kwargs, **kwargs,
): ):
super().__init__() super().__init__()
@ -141,24 +141,24 @@ class ControlNet(nn.Module):
) )
] ]
) )
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)]) self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
self.input_hint_block = TimestepEmbedSequential( self.input_hint_block = TimestepEmbedSequential(
operations.conv_nd(dims, hint_channels, 16, 3, padding=1), operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 16, 16, 3, padding=1), operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2), operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 32, 32, 3, padding=1), operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2), operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 96, 96, 3, padding=1), operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2), operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1)) operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
) )
self._feature_size = model_channels self._feature_size = model_channels
@ -206,7 +206,7 @@ class ControlNet(nn.Module):
) )
) )
self.input_blocks.append(TimestepEmbedSequential(*layers)) self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
self._feature_size += ch self._feature_size += ch
input_block_chans.append(ch) input_block_chans.append(ch)
if level != len(channel_mult) - 1: if level != len(channel_mult) - 1:
@ -234,7 +234,7 @@ class ControlNet(nn.Module):
) )
ch = out_ch ch = out_ch
input_block_chans.append(ch) input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
ds *= 2 ds *= 2
self._feature_size += ch self._feature_size += ch
@ -276,14 +276,14 @@ class ControlNet(nn.Module):
operations=operations operations=operations
)] )]
self.middle_block = TimestepEmbedSequential(*mid_block) self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations) self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch self._feature_size += ch
def make_zero_conv(self, channels, operations=None): def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
def forward(self, x, hint, timesteps, context, y=None, **kwargs): def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context) guided_hint = self.input_hint_block(hint, emb, context)
@ -295,7 +295,7 @@ class ControlNet(nn.Module):
assert y.shape[0] == x.shape[0] assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y) emb = emb + self.label_emb(y)
h = x.type(self.dtype) h = x
for module, zero_conv in zip(self.input_blocks, self.zero_convs): for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None: if guided_hint is not None:
h = module(h, emb, context) h = module(h, emb, context)


@ -55,13 +55,19 @@ fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpvae_group = parser.add_mutually_exclusive_group() fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
fpte_group = parser.add_mutually_exclusive_group() fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
@ -98,7 +104,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

comfy/clip_model.py (new file, 188 lines)

@ -0,0 +1,188 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
self.heads = heads
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None, optimized_attention=None):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = optimized_attention(q, k, v, self.heads, mask)
return self.out_proj(out)
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
}
class CLIPMLP(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
super().__init__()
self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
self.activation = ACTIVATIONS[activation]
self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class CLIPLayer(torch.nn.Module):
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
super().__init__()
self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
def forward(self, x, mask=None, optimized_attention=None):
x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
x += self.mlp(self.layer_norm2(x))
return x
class CLIPEncoder(torch.nn.Module):
def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
super().__init__()
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
for i, l in enumerate(self.layers):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
super().__init__()
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens):
return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
x = self.final_layer_norm(x)
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
self.dtype = dtype
def get_input_embeddings(self):
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, embeddings):
self.text_model.embeddings.token_embedding = embeddings
def forward(self, *args, **kwargs):
return self.text_model(*args, **kwargs)
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
super().__init__()
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
self.patch_embedding = operations.Conv2d(
in_channels=num_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
dtype=dtype,
device=device
)
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
class CLIPVision(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x = self.pre_layrnorm(x)
#TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2])
return (x[0], x[1], out)
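comfy/clip_model.py reimplements the CLIP text and vision towers without depending on transformers. The standalone sketch below mirrors the attention mask construction in CLIPTextModel_.forward above, combining a padding mask with a causal mask into one additive mask; the helper name and example tensor are illustrative only.

import torch

# Sketch of the mask logic used in CLIPTextModel_.forward above.
def build_text_attention_mask(attention_mask: torch.Tensor, dtype=torch.float32):
    b, t = attention_mask.shape
    # Padding positions become -inf, attended positions stay 0.
    mask = 1.0 - attention_mask.to(dtype).reshape(b, 1, 1, t).expand(b, 1, t, t)
    mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
    # Causal part: disallow attending to future tokens.
    causal = torch.empty(t, t, dtype=dtype).fill_(float("-inf")).triu_(1)
    return mask + causal

tokens_valid = torch.tensor([[1, 1, 1, 0, 0]])  # example: 3 real tokens, 2 pad
print(build_text_attention_mask(tokens_valid).shape)  # torch.Size([1, 1, 5, 5])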


@ -1,36 +1,43 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
from .utils import load_torch_file, transformers_convert from .utils import load_torch_file, transformers_convert
import os import os
import torch import torch
import contextlib import json
from . import ops from . import ops
from . import model_patcher from . import model_patcher
from . import model_management from . import model_management
from . import clip_model
class Output:
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224): def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
scale = (size / min(image.shape[1], image.shape[2])) image = image.movedim(-1, 1)
image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) if not (image.shape[2] == size and image.shape[3] == size):
h = (image.shape[2] - size)//2 scale = (size / min(image.shape[2], image.shape[3]))
w = (image.shape[3] - size)//2 image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
image = image[:,:,h:h+size,w:w+size] h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0 image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1]) return (image - mean.view([3,1,1])) / std.view([3,1,1])
class ClipVisionModel(): class ClipVisionModel():
def __init__(self, json_config): def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config) with open(json_config) as f:
config = json.load(f)
self.load_device = model_management.text_encoder_device() self.load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device() offload_device = model_management.text_encoder_offload_device()
self.dtype = torch.float32 self.dtype = model_management.text_encoder_dtype(self.load_device)
if model_management.should_use_fp16(self.load_device, prioritize_performance=False): self.model = clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ops.manual_cast)
self.dtype = torch.float16 self.model.eval()
with ops.use_comfy_ops(offload_device, self.dtype):
with modeling_utils.no_init_weights():
self.model = CLIPVisionModelWithProjection(config)
self.model.to(self.dtype)
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd): def load_sd(self, sd):
@ -38,25 +45,13 @@ class ClipVisionModel():
def encode_image(self, image): def encode_image(self, image):
model_management.load_model_gpu(self.patcher) model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device)) pixel_values = clip_preprocess(image.to(self.load_device)).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
if self.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
with precision_scope(model_management.get_autocast_device(self.load_device), torch.float32):
outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
for k in outputs:
t = outputs[k]
if t is not None:
if k == 'hidden_states':
outputs["penultimate_hidden_states"] = t[-2].cpu()
outputs["hidden_states"] = None
else:
outputs[k] = t.cpu()
outputs = Output()
outputs["last_hidden_state"] = out[0].to(model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(model_management.intermediate_device())
return outputs return outputs
def convert_to_transformers(sd, prefix): def convert_to_transformers(sd, prefix):
@ -86,6 +81,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
if convert_keys: if convert_keys:
sd = convert_to_transformers(sd, prefix) sd = convert_to_transformers(sd, prefix)
if "vision_model.encoder.layers.47.layer_norm1.weight" in sd: if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
# todo: fix the importlib issue here
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")


@ -13,6 +13,8 @@ import typing
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple from typing import Tuple
import sys import sys
import gc
import inspect
import torch import torch
@ -442,6 +444,8 @@ class PromptExecutor:
for x in executed: for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x]) self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None self.server.last_node_id = None
if model_management.DISABLE_SMART_MEMORY:
model_management.unload_all_models()
@ -462,6 +466,14 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
errors = [] errors = []
valid = True valid = True
# todo: investigate if these are at the right indent level
info = None
val = None
validate_function_inputs = []
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
for x in required_inputs: for x in required_inputs:
if x not in inputs: if x not in inputs:
error = { error = {
@ -591,29 +603,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
errors.append(error) errors.append(error)
continue continue
if hasattr(obj_class, "VALIDATE_INPUTS"): if x not in validate_function_inputs:
input_data_all = get_input_data(inputs, obj_class, unique_id)
# ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for i, r3 in enumerate(ret):
if r3 is not True:
details = f"{x}"
if r3 is not False:
details += f" - {str(r3)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
else:
if isinstance(type_input, list): if isinstance(type_input, list):
if val not in type_input: if val not in type_input:
input_config = info input_config = info
@ -640,6 +630,35 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
errors.append(error) errors.append(error)
continue continue
if len(validate_function_inputs) > 0:
input_data_all = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs:
input_filtered[x] = input_data_all[x]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if len(errors) > 0 or valid is not True: if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id) ret = (False, errors, unique_id)
else: else:
@ -771,7 +790,7 @@ class PromptQueue:
self.server.queue_updated() self.server.queue_updated()
self.not_empty.notify() self.not_empty.notify()
def get(self, timeout=None) -> typing.Tuple[QueueTuple, int]: def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
with self.not_empty: with self.not_empty:
while len(self.queue) == 0: while len(self.queue) == 0:
self.not_empty.wait(timeout=timeout) self.not_empty.wait(timeout=timeout)
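The execution.py changes above now inspect VALIDATE_INPUTS and forward only the arguments named in its signature, while every other input keeps going through the generic type checks. The class below is a hypothetical custom node sketch (names and values are invented) illustrating that contract: only width reaches VALIDATE_INPUTS, and returning a string reports a custom validation failure.

# Hypothetical node sketch, not from the commit.
class ExampleResizeNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "width": ("INT", {"default": 512, "min": 64, "max": 8192}),
            "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
        }}

    RETURN_TYPES = ("INT",)
    FUNCTION = "resize"

    @classmethod
    def VALIDATE_INPUTS(cls, width):
        # Only inputs named in this signature are passed in; "scale" is not.
        if width % 8 != 0:
            return f"width must be a multiple of 8, got {width}"
        return True

    def resize(self, width, scale):
        return (int(width * scale),)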


@ -188,8 +188,7 @@ def cached_filename_list_(folder_name):
if folder_name not in filename_list_cache: if folder_name not in filename_list_cache:
return None return None
out = filename_list_cache[folder_name] out = filename_list_cache[folder_name]
if time.perf_counter() < (out[2] + 0.5):
return out
for x in out[1]: for x in out[1]:
time_modified = out[1][x] time_modified = out[1][x]
folder = x folder = x


@ -23,9 +23,9 @@ def execute_prestartup_script():
return False return False
node_paths = folder_paths.get_folder_paths("custom_nodes") node_paths = folder_paths.get_folder_paths("custom_nodes")
node_prestartup_times = []
for custom_node_path in node_paths: for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path) if os.path.exists(custom_node_path) else [] possible_modules = os.listdir(custom_node_path) if os.path.exists(custom_node_path) else []
node_prestartup_times = []
for possible_module in possible_modules: for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module) module_path = os.path.join(custom_node_path, possible_module)
@ -69,6 +69,10 @@ if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device) print("Set cuda device to:", args.cuda_device)
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
from .. import utils from .. import utils
import yaml import yaml
@ -78,12 +82,12 @@ from .server import BinaryEventTypes
from .. import model_management from .. import model_management
def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer): def prompt_worker(q, _server):
e = execution.PromptExecutor(_server) e = execution.PromptExecutor(_server)
last_gc_collect = 0 last_gc_collect = 0
need_gc = False need_gc = False
gc_collect_interval = 10.0 gc_collect_interval = 10.0
current_time = 0.0
while True: while True:
timeout = None timeout = None
if need_gc: if need_gc:
@ -94,11 +98,13 @@ def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer)
item, item_id = queue_item item, item_id = queue_item
execution_start_time = time.perf_counter() execution_start_time = time.perf_counter()
prompt_id = item[1] prompt_id = item[1]
_server.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4]) e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
q.task_done(item_id, e.outputs_ui) q.task_done(item_id, e.outputs_ui)
if _server.client_id is not None: if _server.client_id is not None:
_server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id) _server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id)
current_time = time.perf_counter() current_time = time.perf_counter()
execution_time = current_time - execution_start_time execution_time = current_time - execution_start_time
@ -119,7 +125,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def hijack_progress(server): def hijack_progress(server):
def hook(value, total, preview_image): def hook(value, total, preview_image):
server.send_sync("progress", {"value": value, "max": total}, server.client_id) model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
server.send_sync("progress", progress, server.client_id)
if preview_image is not None: if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
@ -204,7 +213,7 @@ def main():
print(f"Setting output directory to: {output_dir}") print(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir) folder_paths.set_output_directory(output_dir)
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints")) folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip")) folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae")) folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))


@ -734,7 +734,8 @@ class PromptServer():
message = self.encode_bytes(event, data) message = self.encode_bytes(event, data)
if sid is None: if sid is None:
for ws in self.sockets.values(): sockets = list(self.sockets.values())
for ws in sockets:
await send_socket_catch_exception(ws.send_bytes, message) await send_socket_catch_exception(ws.send_bytes, message)
elif sid in self.sockets: elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_bytes, message) await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
@ -743,7 +744,8 @@ class PromptServer():
message = {"type": event, "data": data} message = {"type": event, "data": data}
if sid is None: if sid is None:
for ws in self.sockets.values(): sockets = list(self.sockets.values())
for ws in sockets:
await send_socket_catch_exception(ws.send_json, message) await send_socket_catch_exception(ws.send_json, message)
elif sid in self.sockets: elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_json, message) await send_socket_catch_exception(self.sockets[sid].send_json, message)


@ -1,10 +1,13 @@
import torch import torch
import math import math
import os import os
import contextlib
from . import utils from . import utils
from . import model_management from . import model_management
from . import model_detection from . import model_detection
from . import model_patcher from . import model_patcher
from . import ops
from .cldm import cldm from .cldm import cldm
from .t2i_adapter import adapter from .t2i_adapter import adapter
@ -34,13 +37,13 @@ class ControlBase:
self.cond_hint = None self.cond_hint = None
self.strength = 1.0 self.strength = 1.0
self.timestep_percent_range = (0.0, 1.0) self.timestep_percent_range = (0.0, 1.0)
self.global_average_pooling = False
self.timestep_range = None self.timestep_range = None
if device is None: if device is None:
device = model_management.get_torch_device() device = model_management.get_torch_device()
self.device = device self.device = device
self.previous_controlnet = None self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
@ -75,6 +78,7 @@ class ControlBase:
c.cond_hint_original = self.cond_hint_original c.cond_hint_original = self.cond_hint_original
c.strength = self.strength c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range c.timestep_percent_range = self.timestep_percent_range
c.global_average_pooling = self.global_average_pooling
def inference_memory_requirements(self, dtype): def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
@ -127,12 +131,14 @@ class ControlBase:
return out return out
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None): def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device) super().__init__(device)
self.control_model = control_model self.control_model = control_model
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) self.load_device = load_device
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
self.model_sampling_current = None self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None control_prev = None
@ -146,28 +152,31 @@ class ControlNet(ControlBase):
else: else:
return None return None
dtype = self.control_model.dtype
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None: if self.cond_hint is not None:
del self.cond_hint del self.cond_hint
self.cond_hint = None self.cond_hint = None
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
context = cond['c_crossattn'] context = cond['c_crossattn']
y = cond.get('y', None) y = cond.get('y', None)
if y is not None: if y is not None:
y = y.to(self.control_model.dtype) y = y.to(dtype)
timestep = self.model_sampling_current.timestep(t) timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype) return self.control_merge(None, control, control_prev, output_dtype)
def copy(self): def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
self.copy_to(c) self.copy_to(c)
return c return c
@ -198,10 +207,11 @@ class ControlLoraOps:
self.bias = None self.bias = None
def forward(self, input): def forward(self, input):
weight, bias = ops.cast_bias_weight(self, input)
if self.up is not None: if self.up is not None:
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else: else:
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) return torch.nn.functional.linear(input, weight, bias)
class Conv2d(torch.nn.Module): class Conv2d(torch.nn.Module):
def __init__( def __init__(
@ -237,16 +247,11 @@ class ControlLoraOps:
def forward(self, input): def forward(self, input):
weight, bias = ops.cast_bias_weight(self, input)
if self.up is not None: if self.up is not None:
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else: else:
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
def conv_nd(self, dims, *args, **kwargs):
if dims == 2:
return self.Conv2d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
class ControlLora(ControlNet): class ControlLora(ControlNet):
@ -260,17 +265,26 @@ class ControlLora(ControlNet):
controlnet_config = model.model_config.unet_config.copy() controlnet_config = model.model_config.unet_config.copy()
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps() self.manual_cast_dtype = model.manual_cast_dtype
self.control_model = cldm.ControlNet(**controlnet_config)
dtype = model.get_dtype() dtype = model.get_dtype()
self.control_model.to(dtype) if self.manual_cast_dtype is None:
class control_lora_ops(ControlLoraOps, ops.disable_weight_init):
pass
else:
class control_lora_ops(ControlLoraOps, ops.manual_cast):
pass
dtype = self.manual_cast_dtype
controlnet_config["operations"] = control_lora_ops
controlnet_config["dtype"] = dtype
self.control_model = cldm.ControlNet(**controlnet_config)
self.control_model.to(model_management.get_torch_device()) self.control_model.to(model_management.get_torch_device())
diffusion_model = model.diffusion_model diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict() sd = diffusion_model.state_dict()
cm = self.control_model.state_dict() cm = self.control_model.state_dict()
for k in sd: for k in sd:
weight = model_management.resolve_lowvram_weight(sd[k], diffusion_model, k) weight = sd[k]
try: try:
utils.set_attr(self.control_model, k, weight) utils.set_attr(self.control_model, k, weight)
except: except:
@ -367,6 +381,10 @@ def load_controlnet(ckpt_path, model=None):
if controlnet_config is None: if controlnet_config is None:
unet_dtype = model_management.unet_dtype() unet_dtype = model_management.unet_dtype()
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = ops.manual_cast
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = cldm.ControlNet(**controlnet_config) control_model = cldm.ControlNet(**controlnet_config)
@ -395,14 +413,12 @@ def load_controlnet(ckpt_path, model=None):
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected) print(missing, unexpected)
control_model = control_model.to(unet_dtype)
global_average_pooling = False global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0] filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling) control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control return control
class T2IAdapter(ControlBase): class T2IAdapter(ControlBase):


@ -33,3 +33,7 @@ class SDXL(LatentFormat):
[-0.3112, -0.2359, -0.2076] [-0.3112, -0.2359, -0.2076]
] ]
self.taesd_decoder_name = "taesdxl_decoder" self.taesd_decoder_name = "taesdxl_decoder"
class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333


@ -9,6 +9,7 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..util import instantiate_from_config, get_obj_from_str from ..util import instantiate_from_config, get_obj_from_str
from ..modules.ema import LitEma from ..modules.ema import LitEma
from ... import ops
class DiagonalGaussianRegularizer(torch.nn.Module): class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True): def __init__(self, sample: bool = True):
@ -162,12 +163,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
}, },
**kwargs, **kwargs,
) )
self.quant_conv = torch.nn.Conv2d( self.quant_conv = ops.disable_weight_init.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"], (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim, (1 + ddconfig["double_z"]) * embed_dim,
1, 1,
) )
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list: def get_autoencoder_params(self) -> list:


@ -18,6 +18,7 @@ if model_management.xformers_enabled():
from ...cli_args import args from ...cli_args import args
from ... import ops from ... import ops
ops = ops.disable_weight_init
# CrossAttn precision handling # CrossAttn precision handling
if args.dont_upcast_attention: if args.dont_upcast_attention:
@ -82,16 +83,6 @@ class FeedForward(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels, dtype=None, device=None): def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
@ -112,19 +103,20 @@ def attention_basic(q, k, v, heads, mask=None):
# force cast to fp32 to avoid overflowing # force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32": if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'): sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * scale
else: else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k del q, k
if exists(mask): if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)') if mask.dtype == torch.bool:
max_neg_value = -torch.finfo(sim.dtype).max mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
mask = repeat(mask, 'b j -> (b h) () j', h=h) max_neg_value = -torch.finfo(sim.dtype).max
sim.masked_fill_(~mask, max_neg_value) mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
else:
sim += mask
# attention, what we cannot get enough of # attention, what we cannot get enough of
sim = sim.softmax(dim=-1) sim = sim.softmax(dim=-1)
@ -349,6 +341,18 @@ else:
if model_management.pytorch_attention_enabled(): if model_management.pytorch_attention_enabled():
optimized_attention_masked = attention_pytorch optimized_attention_masked = attention_pytorch
def optimized_attention_for_device(device, mask=False):
if device == torch.device("cpu"): #TODO
if model_management.pytorch_attention_enabled():
return attention_pytorch
else:
return attention_basic
if mask:
return optimized_attention_masked
return optimized_attention
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__() super().__init__()
@ -393,7 +397,7 @@ class BasicTransformerBlock(nn.Module):
self.is_res = inner_dim == dim self.is_res = inner_dim == dim
if self.ff_in: if self.ff_in:
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device) self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.disable_self_attn = disable_self_attn self.disable_self_attn = disable_self_attn
@ -413,10 +417,10 @@ class BasicTransformerBlock(nn.Module):
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint self.checkpoint = checkpoint
self.n_heads = n_heads self.n_heads = n_heads
self.d_head = d_head self.d_head = d_head
@ -558,7 +562,7 @@ class SpatialTransformer(nn.Module):
context_dim = [context_dim] * depth context_dim = [context_dim] * depth
self.in_channels = in_channels self.in_channels = in_channels
inner_dim = n_heads * d_head inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype, device=device) self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
if not use_linear: if not use_linear:
self.proj_in = operations.Conv2d(in_channels, self.proj_in = operations.Conv2d(in_channels,
inner_dim, inner_dim,
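The attention.py hunk above teaches attention_basic to accept either a boolean mask or an additive float mask. The following sketch is an assumption rather than code from the commit: one way to convert a boolean keep-mask into the additive form so either convention can feed the same attention call.

import torch

# Convert a boolean keep-mask (True = attend) into an additive score mask.
def to_additive_mask(bool_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    additive = torch.zeros(bool_mask.shape, dtype=dtype)
    additive.masked_fill_(~bool_mask, -torch.finfo(dtype).max)
    return additive

keep = torch.tensor([[True, True, False]])
print(to_additive_mask(keep))  # tensor([[ 0.0000e+00,  0.0000e+00, -3.4028e+38]])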


@ -8,6 +8,7 @@ from typing import Optional, Any
from .... import model_management from .... import model_management
from .... import ops from .... import ops
ops = ops.disable_weight_init
if model_management.xformers_enabled_vae(): if model_management.xformers_enabled_vae():
import xformers import xformers
@ -40,7 +41,7 @@ def nonlinearity(x):
def Normalize(in_channels, num_groups=32): def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module): class Upsample(nn.Module):


@ -12,13 +12,13 @@ from .util import (
checkpoint, checkpoint,
avg_pool_nd, avg_pool_nd,
zero_module, zero_module,
normalization,
timestep_embedding, timestep_embedding,
AlphaBlender, AlphaBlender,
) )
from ..attention import SpatialTransformer, SpatialVideoTransformer, default from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from ...util import exists from ...util import exists
from .... import ops from .... import ops
ops = ops.disable_weight_init
class TimestepBlock(nn.Module): class TimestepBlock(nn.Module):
""" """
@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
padding = kernel_size // 2 padding = kernel_size // 2
self.in_layers = nn.Sequential( self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device), operations.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
) )
@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
), ),
) )
self.out_layers = nn.Sequential( self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) ,
),
) )
if self.out_channels == channels: if self.out_channels == channels:
@ -438,9 +437,6 @@ class UNetModel(nn.Module):
operations=ops, operations=ops,
): ):
super().__init__() super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None: if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
@ -457,7 +453,6 @@ class UNetModel(nn.Module):
if num_head_channels == -1: if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels self.in_channels = in_channels
self.model_channels = model_channels self.model_channels = model_channels
self.out_channels = out_channels self.out_channels = out_channels
@ -503,7 +498,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None: if self.num_classes is not None:
if isinstance(self.num_classes, int): if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim) self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous": elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer") print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim) self.label_emb = nn.Linear(1, time_embed_dim)
@ -810,13 +805,13 @@ class UNetModel(nn.Module):
self._feature_size += ch self._feature_size += ch
self.out = nn.Sequential( self.out = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device), operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
) )
if self.predict_codebook_ids: if self.predict_codebook_ids:
self.id_predictor = nn.Sequential( self.id_predictor = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device), operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
) )
@ -842,14 +837,14 @@ class UNetModel(nn.Module):
self.num_classes is not None self.num_classes is not None
), "must specify y if and only if the model is class-conditional" ), "must specify y if and only if the model is class-conditional"
hs = [] hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb) emb = self.time_embed(t_emb)
if self.num_classes is not None: if self.num_classes is not None:
assert y.shape[0] == x.shape[0] assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y) emb = emb + self.label_emb(y)
h = x.type(self.dtype) h = x
for id, module in enumerate(self.input_blocks): for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id) transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
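
The hunks above thread an `operations` namespace through the UNet so the same code can build either plain layers (ops.disable_weight_init) or per-call-casting layers (ops.manual_cast); t_emb is also computed in x.dtype so a manually cast model sees matching dtypes. A simplified sketch, assuming comfy.ops is importable; the block and shapes are toy stand-ins, not the real ResBlock:

    import torch
    import torch.nn as nn
    from comfy import ops

    class TinyResBlock(nn.Module):
        def __init__(self, channels, out_channels, dtype=None, device=None,
                     operations=ops.disable_weight_init):
            super().__init__()
            self.in_layers = nn.Sequential(
                operations.GroupNorm(32, channels, dtype=dtype, device=device),
                nn.SiLU(),
                operations.conv_nd(2, channels, out_channels, 3, padding=1,
                                   dtype=dtype, device=device),
            )

        def forward(self, x):
            return self.in_layers(x)

    # fp16 storage, fp32 input: manual_cast layers cast their weights on every call.
    # Weights are uninitialized here; in practice they are loaded from a checkpoint.
    block = TinyResBlock(32, 32, dtype=torch.float16, operations=ops.manual_cast)
    out = block(torch.randn(1, 32, 8, 8))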

View File

@ -41,10 +41,14 @@ class AbstractLowScaleModel(nn.Module):
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise=None, seed=None):
noise = default(noise, lambda: torch.randn_like(x_start)) if noise is None:
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + if seed is None:
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) noise = torch.randn_like(x_start)
else:
noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
def forward(self, x): def forward(self, x):
return x, None return x, None
@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
super().__init__(noise_schedule_config=noise_schedule_config) super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level self.max_noise_level = max_noise_level
def forward(self, x, noise_level=None): def forward(self, x, noise_level=None, seed=None):
if noise_level is None: if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else: else:
assert isinstance(noise_level, torch.Tensor) assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level) z = self.q_sample(x, noise_level, seed=seed)
return z, noise_level return z, noise_level
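
q_sample and the ImageConcatWithNoiseAugmentation forward now accept an optional seed so noise augmentation is reproducible for a given sampling seed. A standalone sketch of the same idea; it uses a private torch.Generator rather than reseeding the default one as the code above does:

    import torch

    def make_noise_like(x, seed=None):
        if seed is None:
            return torch.randn_like(x)
        generator = torch.Generator().manual_seed(seed)
        return torch.randn(x.size(), dtype=x.dtype, layout=x.layout,
                           generator=generator).to(x.device)

    x = torch.zeros(1, 4, 8, 8)
    assert torch.equal(make_noise_like(x, seed=42), make_noise_like(x, seed=42))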

View File

@ -16,7 +16,6 @@ import numpy as np
from einops import repeat, rearrange from einops import repeat, rearrange
from ...util import instantiate_from_config from ...util import instantiate_from_config
from .... import ops
class AlphaBlender(nn.Module): class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"] strategies = ["learned", "fixed", "learned_with_images"]
@ -52,9 +51,9 @@ class AlphaBlender(nn.Module):
if self.merge_strategy == "fixed": if self.merge_strategy == "fixed":
# make shape compatible # make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor alpha = self.mix_factor.to(image_only_indicator.device)
elif self.merge_strategy == "learned": elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor) alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
# make shape compatible # make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs) # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images": elif self.merge_strategy == "learned_with_images":
@ -62,7 +61,7 @@ class AlphaBlender(nn.Module):
alpha = torch.where( alpha = torch.where(
image_only_indicator.bool(), image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device), torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
) )
alpha = rearrange(alpha, self.rearrange_pattern) alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible # make shape compatible
@ -273,46 +272,6 @@ def mean_flat(tensor):
return tensor.mean(dim=list(range(1, len(tensor.shape)))) return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels, dtype=None):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels, dtype=dtype)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return ops.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return ops.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs): def avg_pool_nd(dims, *args, **kwargs):
""" """
Create a 1D, 2D, or 3D average pooling module. Create a 1D, 2D, or 3D average pooling module.

View File

@ -15,12 +15,12 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def scale(self, x): def scale(self, x):
# re-normalize to centered mean and unit variance # re-normalize to centered mean and unit variance
x = (x - self.data_mean) * 1. / self.data_std x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
return x return x
def unscale(self, x): def unscale(self, x):
# back to original data stats # back to original data stats
x = (x * self.data_std) + self.data_mean x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x return x
def forward(self, x, noise_level=None): def forward(self, x, noise_level=None):

View File

@ -5,6 +5,7 @@ import torch
from einops import rearrange, repeat from einops import rearrange, repeat
from ... import ops from ... import ops
ops = ops.disable_weight_init
from .diffusionmodules.model import ( from .diffusionmodules.model import (
AttnBlock, AttnBlock,
@ -81,14 +82,14 @@ class VideoResBlock(ResnetBlock):
x = self.time_stack(x, temb) x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps) alpha = self.get_alpha(bs=b // timesteps).to(x.device)
x = alpha * x + (1.0 - alpha) * x_mix x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w") x = rearrange(x, "b c t h w -> (b t) c h w")
return x return x
class AE3DConv(torch.nn.Conv2d): class AE3DConv(ops.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs) super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable): if isinstance(video_kernel_size, Iterable):
@ -96,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
else: else:
padding = int(video_kernel_size // 2) padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d( self.time_mix_conv = ops.Conv3d(
in_channels=out_channels, in_channels=out_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=video_kernel_size, kernel_size=video_kernel_size,
@ -166,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
emb = emb[:, None, :] emb = emb[:, None, :]
x_mix = x_mix + emb x_mix = x_mix + emb
alpha = self.get_alpha() alpha = self.get_alpha().to(x.device)
x_mix = self.time_mix_block(x_mix, timesteps=timesteps) x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
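
AE3DConv now derives from ops.Conv2d and its time_mix_conv uses ops.Conv3d, so the video VAE layers get the same skip-init / manual-cast behaviour as the rest of the model. A cut-down sketch of the pattern (assuming comfy.ops and einops; padding handling and weight loading are simplified):

    import torch
    from einops import rearrange
    from comfy import ops

    class TinyAE3DConv(ops.disable_weight_init.Conv2d):
        def __init__(self, in_channels, out_channels, video_kernel_size=3, **kwargs):
            super().__init__(in_channels, out_channels, **kwargs)
            self.time_mix_conv = ops.disable_weight_init.Conv3d(
                out_channels, out_channels,
                kernel_size=video_kernel_size, padding=video_kernel_size // 2)

        def forward(self, x, timesteps=2):
            x = super().forward(x)                                   # (b*t, c, h, w)
            x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
            x = self.time_mix_conv(x)                                # mix across time
            return rearrange(x, "b c t h w -> (b t) c h w")

    conv = TinyAE3DConv(4, 4, kernel_size=3, padding=1)
    out = conv(torch.randn(2, 4, 8, 8))   # two frames of one video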

View File

@ -43,7 +43,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys(): if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name] mid = lora[mid_name]
loaded_keys.add(mid_name) loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
loaded_keys.add(A_name) loaded_keys.add(A_name)
loaded_keys.add(B_name) loaded_keys.add(B_name)
@ -64,7 +64,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name) loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name) loaded_keys.add(hada_w2_a_name)
@ -116,8 +116,19 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name) loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
#glora
a1_name = "{}.a1.weight".format(x)
a2_name = "{}.a2.weight".format(x)
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
loaded_keys.add(b2_name)
w_norm_name = "{}.w_norm".format(x) w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x) b_norm_name = "{}.b_norm".format(x)
@ -126,21 +137,21 @@ def load_lora(lora, to_load):
if w_norm is not None: if w_norm is not None:
loaded_keys.add(w_norm_name) loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = (w_norm,) patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None: if b_norm is not None:
loaded_keys.add(b_norm_name) loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
diff_name = "{}.diff".format(x) diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None) diff_weight = lora.get(diff_name, None)
if diff_weight is not None: if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,) patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name) loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x) diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None) diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None: if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name) loaded_keys.add(diff_bias_name)
for x in lora.keys(): for x in lora.keys():
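
load_lora now tags every patch_dict entry with its type ("lora", "loha", "lokr", "glora", "diff") instead of relying on tuple length, which is what lets calculate_weight further down in this diff dispatch cleanly. A toy sketch of the new layout (random tensors, not a real checkpoint):

    import torch

    A = torch.randn(4, 16)    # lora_up:   (out_features, rank)
    B = torch.randn(16, 8)    # lora_down: (rank, in_features)

    patch_dict = {
        "some.layer.weight":  ("lora", (A, B, 1.0, None)),      # was a bare 4-tuple before
        "other.layer.weight": ("diff", (torch.randn(4, 8),)),   # was a bare 1-tuple before
    }

    for key, (patch_type, payload) in patch_dict.items():
        if patch_type == "lora":
            mat1, mat2, alpha, mid = payload
            # the real code also scales alpha by 1/rank and casts to the weight's device
            delta = alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))
        elif patch_type == "diff":
            delta = payload[0]
        print(key, patch_type, tuple(delta.shape))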

View File

@ -1,10 +1,12 @@
import torch import torch
from .ldm.modules.diffusionmodules.openaimodel import UNetModel from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.modules.diffusionmodules.openaimodel import Timestep from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from . import model_management from . import model_management
from . import conds from . import conds
from . import ops
from enum import Enum from enum import Enum
import contextlib
from . import utils from . import utils
class ModelType(Enum): class ModelType(Enum):
@ -13,7 +15,7 @@ class ModelType(Enum):
V_PREDICTION_EDM = 3 V_PREDICTION_EDM = 3
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM from .model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
@ -40,9 +42,14 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config unet_config = model_config.unet_config
self.latent_format = model_config.latent_format self.latent_format = model_config.latent_format
self.model_config = model_config self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device) if self.manual_cast_dtype is not None:
operations = ops.manual_cast
else:
operations = ops.disable_weight_init
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
self.model_type = model_type self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type) self.model_sampling = model_sampling(model_config, model_type)
@ -61,15 +68,21 @@ class BaseModel(torch.nn.Module):
context = c_crossattn context = c_crossattn
dtype = self.get_dtype() dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
xc = xc.to(dtype) xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float() t = self.model_sampling.timestep(t).float()
context = context.to(dtype) context = context.to(dtype)
extra_conds = {} extra_conds = {}
for o in kwargs: for o in kwargs:
extra = kwargs[o] extra = kwargs[o]
if hasattr(extra, "to"): if hasattr(extra, "dtype"):
extra = extra.to(dtype) if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra_conds[o] = extra extra_conds[o] = extra
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x) return self.model_sampling.calculate_denoised(sigma, model_output, x)
@ -117,6 +130,10 @@ class BaseModel(torch.nn.Module):
adm = self.encode_adm(**kwargs) adm = self.encode_adm(**kwargs)
if adm is not None: if adm is not None:
out['y'] = conds.CONDRegular(adm) out['y'] = conds.CONDRegular(adm)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
return out return out
def load_model_weights(self, sd, unet_prefix=""): def load_model_weights(self, sd, unet_prefix=""):
@ -144,11 +161,7 @@ class BaseModel(torch.nn.Module):
def state_dict_for_saving(self, clip_state_dict, vae_state_dict): def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_sd = self.diffusion_model.state_dict() unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = {}
for k in unet_sd:
unet_state_dict[k] = model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16: if self.get_dtype() == torch.float16:
@ -165,9 +178,12 @@ class BaseModel(torch.nn.Module):
def memory_required(self, input_shape): def memory_required(self, input_shape):
if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention(): if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked #TODO: this needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3] area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024) return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
else: else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3] area = input_shape[0] * input_shape[2] * input_shape[3]
@ -307,9 +323,75 @@ class SVD_img2vid(BaseModel):
out['c_concat'] = conds.CONDNoiseShape(latent_image) out['c_concat'] = conds.CONDNoiseShape(latent_image)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
if "time_conditioning" in kwargs: if "time_conditioning" in kwargs:
out["time_context"] = conds.CONDCrossAttn(kwargs["time_conditioning"]) out["time_context"] = conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = conds.CONDConstant(torch.zeros((1,), device=device)) out['image_only_indicator'] = conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = conds.CONDConstant(noise.shape[0]) out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
return out return out
class Stable_Zero123(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
super().__init__(model_config, model_type, device=device)
self.cc_projection = ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
self.cc_projection.weight.copy_(cc_projection_weight)
self.cc_projection.bias.copy_(cc_projection_bias)
def extra_conds(self, **kwargs):
out = {}
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = conds.CONDNoiseShape(latent_image)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if cross_attn.shape[-1] != 768:
cross_attn = self.cc_projection(cross_attn)
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
return out
class SD_X4Upscaler(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device)
self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)
def extra_conds(self, **kwargs):
out = {}
image = kwargs.get("concat_image", None)
noise = kwargs.get("noise", None)
noise_augment = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
seed = kwargs["seed"] - 10
noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
if image is None:
image = torch.zeros_like(noise)[:,:3]
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
noise_level = torch.tensor([noise_level], device=device)
if noise_augment > 0:
image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)
image = utils.resize_to_batch_size(image, noise.shape[0])
out['c_concat'] = conds.CONDNoiseShape(image)
out['y'] = conds.CONDRegular(noise_level)
return out
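
BaseModel now records manual_cast_dtype from the model config, picks the layer namespace from it, and only casts floating-point extra conds. A minimal sketch of those two decisions in isolation (assuming comfy.ops; UNet construction elided):

    import torch
    from comfy import ops

    def pick_operations(manual_cast_dtype):
        # None means the stored weight dtype can be computed with directly;
        # otherwise weights are cast per call to manual_cast_dtype.
        if manual_cast_dtype is not None:
            return ops.manual_cast
        return ops.disable_weight_init

    def cast_extra_cond(extra, dtype):
        # integer conds (class labels, frame counts, ...) are left untouched
        if hasattr(extra, "dtype") and extra.dtype not in (torch.int, torch.long):
            return extra.to(dtype)
        return extra

    assert pick_operations(None) is ops.disable_weight_init
    assert pick_operations(torch.float16) is ops.manual_cast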

View File

@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config = { unet_config = {
"use_checkpoint": False, "use_checkpoint": False,
"image_size": 32, "image_size": 32,
"out_channels": 4,
"use_spatial_transformer": True, "use_spatial_transformer": True,
"legacy": False "legacy": False
} }
@ -50,6 +49,12 @@ def detect_unet_config(state_dict, key_prefix, dtype):
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
out_key = '{}out.2.weight'.format(key_prefix)
if out_key in state_dict:
out_channels = state_dict[out_key].shape[0]
else:
out_channels = 4
num_res_blocks = [] num_res_blocks = []
channel_mult = [] channel_mult = []
attention_resolutions = [] attention_resolutions = []
@ -122,6 +127,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
transformer_depth_middle = -1 transformer_depth_middle = -1
unet_config["in_channels"] = in_channels unet_config["in_channels"] = in_channels
unet_config["out_channels"] = out_channels
unet_config["model_channels"] = model_channels unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks unet_config["num_res_blocks"] = num_res_blocks
unet_config["transformer_depth"] = transformer_depth unet_config["transformer_depth"] = transformer_depth
@ -289,7 +295,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False} 'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B] Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega]
for unet_config in supported_models: for unet_config in supported_models:
matches = True matches = True
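
out_channels is no longer hard-coded to 4: it is read from the final output conv when the key is present, which is what models like the x4 upscaler need. A small sketch of the lookup against a fake state dict (the key layout follows the ldm UNet naming used above):

    import torch

    def detect_out_channels(state_dict, key_prefix="model.diffusion_model."):
        out_key = "{}out.2.weight".format(key_prefix)
        if out_key in state_dict:
            return state_dict[out_key].shape[0]
        return 4   # previous hard-coded default

    sd = {"model.diffusion_model.out.2.weight": torch.zeros(4, 320, 3, 3)}
    assert detect_out_channels(sd) == 4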

View File

@ -2,6 +2,7 @@ import psutil
from enum import Enum from enum import Enum
from .cli_args import args from .cli_args import args
from . import utils from . import utils
import torch import torch
import sys import sys
@ -28,6 +29,10 @@ total_vram = 0
lowvram_available = True lowvram_available = True
xpu_available = False xpu_available = False
if args.deterministic:
print("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True)
directml_enabled = False directml_enabled = False
if args.directml is not None: if args.directml is not None:
import torch_directml import torch_directml
@ -182,6 +187,9 @@ except:
if is_intel_xpu(): if is_intel_xpu():
VAE_DTYPE = torch.bfloat16 VAE_DTYPE = torch.bfloat16
if args.cpu_vae:
VAE_DTYPE = torch.float32
if args.fp16_vae: if args.fp16_vae:
VAE_DTYPE = torch.float16 VAE_DTYPE = torch.float16
elif args.bf16_vae: elif args.bf16_vae:
@ -214,15 +222,8 @@ if args.force_fp16 or cpu_state == CPUState.MPS:
FORCE_FP16 = True FORCE_FP16 = True
if lowvram_available: if lowvram_available:
try: if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
import accelerate vram_state = set_vram_to
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
if cpu_state != CPUState.GPU: if cpu_state != CPUState.GPU:
@ -262,6 +263,14 @@ print("VAE dtype:", VAE_DTYPE)
current_loaded_models = [] current_loaded_models = []
def module_size(module):
module_mem = 0
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
return module_mem
class LoadedModel: class LoadedModel:
def __init__(self, model): def __init__(self, model):
self.model = model self.model = model
@ -294,8 +303,20 @@ class LoadedModel:
if lowvram_model_memory > 0: if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024)) print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) mem_counter = 0
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device) for m in self.real_model.modules():
if hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
module_mem = module_size(m)
if mem_counter + module_mem < lowvram_model_memory:
m.to(self.device)
mem_counter += module_mem
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
m.to(self.device)
mem_counter += module_size(m)
print("lowvram: loaded module regularly", m)
self.model_accelerated = True self.model_accelerated = True
if is_intel_xpu() and not args.disable_ipex_optimize: if is_intel_xpu() and not args.disable_ipex_optimize:
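
The accelerate device map is replaced by a plain pass over the modules: every comfy.ops layer gets comfy_cast_weights = True, and layers are moved to the GPU until the lowvram budget runs out, leaving the rest on the offload device to be cast per call. A standalone sketch of the budgeting loop (assuming comfy.ops; the real code also tracks prev_comfy_cast_weights and handles plain nn modules):

    import torch
    from comfy import ops

    def module_size(module):
        return sum(t.nelement() * t.element_size() for t in module.state_dict().values())

    def load_within_budget(model, device, budget_bytes):
        mem_counter = 0
        for m in model.modules():
            if hasattr(m, "comfy_cast_weights"):
                m.comfy_cast_weights = True          # cast on the fly if left offloaded
                size = module_size(m)
                if mem_counter + size < budget_bytes:
                    m.to(device)                     # this layer fits, keep it resident
                    mem_counter += size
        return mem_counter

    net = torch.nn.Sequential(ops.disable_weight_init.Linear(8, 8),
                              ops.disable_weight_init.Linear(8, 8))
    used = load_within_budget(net, torch.device("cpu"), budget_bytes=1 << 20)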
@ -305,7 +326,11 @@ class LoadedModel:
def model_unload(self): def model_unload(self):
if self.model_accelerated: if self.model_accelerated:
accelerate.hooks.remove_hook_from_submodules(self.real_model) for m in self.real_model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
self.model_accelerated = False self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device) self.model.unpatch_model(self.model.offload_device)
@ -398,14 +423,14 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev) model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev) current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM vram_set_state = VRAMState.LOW_VRAM
else: else:
lowvram_model_memory = 0 lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 256 * 1024 * 1024 lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory) cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model) current_loaded_models.insert(0, loaded_model)
@ -430,6 +455,13 @@ def dtype_size(dtype):
dtype_size = 4 dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16: if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2 dtype_size = 2
elif dtype == torch.float32:
dtype_size = 4
else:
try:
dtype_size = dtype.itemsize
except: #Old pytorch doesn't have .itemsize
pass
return dtype_size return dtype_size
def unet_offload_device(): def unet_offload_device():
@ -459,10 +491,30 @@ def unet_inital_load_device(parameters, dtype):
def unet_dtype(device=None, model_params=0): def unet_dtype(device=None, model_params=0):
if args.bf16_unet: if args.bf16_unet:
return torch.bfloat16 return torch.bfloat16
if args.fp16_unet:
return torch.float16
if args.fp8_e4m3fn_unet:
return torch.float8_e4m3fn
if args.fp8_e5m2_unet:
return torch.float8_e5m2
if should_use_fp16(device=device, model_params=model_params): if should_use_fp16(device=device, model_params=model_params):
return torch.float16 return torch.float16
return torch.float32 return torch.float32
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device):
if weight_dtype == torch.float32:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None
if fp16_supported:
return torch.float16
else:
return torch.float32
def text_encoder_offload_device(): def text_encoder_offload_device():
if args.gpu_only: if args.gpu_only:
return get_torch_device() return get_torch_device()
@ -492,12 +544,23 @@ def text_encoder_dtype(device=None):
elif args.fp32_text_enc: elif args.fp32_text_enc:
return torch.float32 return torch.float32
if is_device_cpu(device):
return torch.float16
if should_use_fp16(device, prioritize_performance=False): if should_use_fp16(device, prioritize_performance=False):
return torch.float16 return torch.float16
else: else:
return torch.float32 return torch.float32
def intermediate_device():
if args.gpu_only:
return get_torch_device()
else:
return torch.device("cpu")
def vae_device(): def vae_device():
if args.cpu_vae:
return torch.device("cpu")
return get_torch_device() return get_torch_device()
def vae_offload_device(): def vae_offload_device():
@ -515,6 +578,22 @@ def get_autocast_device(dev):
return dev.type return dev.type
return "cuda" return "cuda"
def supports_dtype(device, dtype): #TODO
if dtype == torch.float32:
return True
if is_device_cpu(device):
return False
if dtype == torch.float16:
return True
if dtype == torch.bfloat16:
return True
return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
return True
def cast_to_device(tensor, device, dtype, copy=False): def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@ -525,15 +604,17 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu(): elif is_intel_xpu():
device_supports_cast = True device_supports_cast = True
non_blocking = device_supports_non_blocking(device)
if device_supports_cast: if device_supports_cast:
if copy: if copy:
if tensor.device == device: if tensor.device == device:
return tensor.to(dtype, copy=copy) return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, copy=copy).to(dtype) return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else: else:
return tensor.to(device).to(dtype) return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else: else:
return tensor.to(dtype).to(device, copy=copy) return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def xformers_enabled(): def xformers_enabled():
global directml_enabled global directml_enabled
@ -687,11 +768,11 @@ def soft_empty_cache(force=False):
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
def resolve_lowvram_weight(weight, model, key): def unload_all_models():
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. free_memory(1e30, get_torch_device())
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]] def resolve_lowvram_weight(weight, model, key): #TODO: remove
return weight return weight
#TODO: might be cleaner to put this somewhere else #TODO: might be cleaner to put this somewhere else
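
The new unet_manual_cast is the core of the manual-cast feature: it returns None when the stored weight dtype can be used for inference directly, and otherwise the dtype to cast to at call time. A standalone sketch with the device probing folded into a boolean (the real function calls should_use_fp16 on the inference device):

    import torch

    def unet_manual_cast(weight_dtype, fp16_supported):
        if weight_dtype == torch.float32:
            return None                                  # fp32 weights always run as-is
        if fp16_supported and weight_dtype == torch.float16:
            return None                                  # fp16 weights on an fp16-capable device
        return torch.float16 if fp16_supported else torch.float32

    # e.g. bf16- or fp8-stored weights on an fp16 GPU are cast to fp16 per call
    assert unet_manual_cast(torch.bfloat16, fp16_supported=True) == torch.float16
    assert unet_manual_cast(torch.float16, fp16_supported=False) == torch.float32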

View File

@ -28,13 +28,9 @@ class ModelPatcher:
if self.size > 0: if self.size > 0:
return self.size return self.size
model_sd = self.model.state_dict() model_sd = self.model.state_dict()
size = 0 self.size = model_management.module_size(self.model)
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys()) self.model_keys = set(model_sd.keys())
return size return self.size
def clone(self): def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
@ -55,11 +51,18 @@ class ModelPatcher:
def memory_required(self, input_shape): def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape) return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function): def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3: if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else: else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_unet_function_wrapper(self, unet_wrapper_function): def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function self.model_options["model_function_wrapper"] = unet_wrapper_function
@ -70,13 +73,17 @@ class ModelPatcher:
to["patches"] = {} to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch] to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number): def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
to = self.model_options["transformer_options"] to = self.model_options["transformer_options"]
if "patches_replace" not in to: if "patches_replace" not in to:
to["patches_replace"] = {} to["patches_replace"] = {}
if name not in to["patches_replace"]: if name not in to["patches_replace"]:
to["patches_replace"][name] = {} to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch if transformer_index is not None:
block = (block_name, number, transformer_index)
else:
block = (block_name, number)
to["patches_replace"][name][block] = patch
def set_model_attn1_patch(self, patch): def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch") self.set_model_patch(patch, "attn1_patch")
@ -84,11 +91,11 @@ class ModelPatcher:
def set_model_attn2_patch(self, patch): def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch") self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number): def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn1", block_name, number) self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
def set_model_attn2_replace(self, patch, block_name, number): def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn2", block_name, number) self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
def set_model_attn1_output_patch(self, patch): def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch") self.set_model_patch(patch, "attn1_output_patch")
@ -167,40 +174,41 @@ class ModelPatcher:
sd.pop(k) sd.pop(k)
return sd return sd
def patch_model(self, device_to=None): def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches: for k in self.object_patches:
old = getattr(self.model, k) old = getattr(self.model, k)
if k not in self.object_patches_backup: if k not in self.object_patches_backup:
self.object_patches_backup[k] = old self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k]) setattr(self.model, k, self.object_patches[k])
model_sd = self.model_state_dict() if patch_weights:
for key in self.patches: model_sd = self.model_state_dict()
if key not in model_sd: for key in self.patches:
print("could not patch. key doesn't exist in model:", key) if key not in model_sd:
continue print("could not patch. key doesn't exist in model:", key)
continue
weight = model_sd[key] weight = model_sd[key]
inplace_update = self.weight_inplace_update inplace_update = self.weight_inplace_update
if key not in self.backup: if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None: if device_to is not None:
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True) self.model.to(device_to)
else: self.current_device = device_to
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
return self.model return self.model
@ -217,13 +225,19 @@ class ModelPatcher:
v = (self.calculate_weight(v[1:], v[0].clone(), key), ) v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1: if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0] w1 = v[0]
if alpha != 0.0: if alpha != 0.0:
if w1.shape != weight.shape: if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else: else:
weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype) weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
elif len(v) == 4: #lora/locon elif patch_type == "lora": #lora/locon
mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32) mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32) mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
if v[2] is not None: if v[2] is not None:
@ -237,7 +251,7 @@ class ModelPatcher:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e: except Exception as e:
print("ERROR", key, e) print("ERROR", key, e)
elif len(v) == 8: #lokr elif patch_type == "lokr":
w1 = v[0] w1 = v[0]
w2 = v[1] w2 = v[1]
w1_a = v[3] w1_a = v[3]
@ -276,7 +290,7 @@ class ModelPatcher:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e: except Exception as e:
print("ERROR", key, e) print("ERROR", key, e)
else: #loha elif patch_type == "loha":
w1a = v[0] w1a = v[0]
w1b = v[1] w1b = v[1]
if v[2] is not None: if v[2] is not None:
@ -305,6 +319,18 @@ class ModelPatcher:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e: except Exception as e:
print("ERROR", key, e) print("ERROR", key, e)
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
else:
print("patch type not recognized", patch_type, key)
return weight return weight
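
calculate_weight now branches on the patch type tags from load_lora and adds a "glora" branch; the update applied to a weight W is roughly W += alpha * (B2·B1 + (W·A2)·A1) on flattened 2-D views. A toy numeric sketch (random shapes chosen so the products line up; real checkpoints define them, and the real code also divides alpha by a rank and casts devices):

    import torch

    def apply_glora(weight, a1, a2, b1, b2, alpha):
        w = weight.flatten(start_dim=1)
        update = torch.mm(b2, b1) + torch.mm(torch.mm(w, a2), a1)
        return weight + (update * alpha).reshape(weight.shape).type(weight.dtype)

    out_f, in_f, rank = 8, 6, 4
    weight = torch.randn(out_f, in_f)
    a1 = torch.randn(rank, in_f)
    a2 = torch.randn(in_f, rank)
    b1 = torch.randn(rank, in_f)
    b2 = torch.randn(out_f, rank)
    patched = apply_glora(weight, a1, a2, b1, b2, alpha=1.0)
    assert patched.shape == weight.shape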

View File

@ -22,10 +22,17 @@ class V_PREDICTION(EPS):
class ModelSamplingDiscrete(torch.nn.Module): class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None): def __init__(self, model_config=None):
super().__init__() super().__init__()
beta_schedule = "linear"
if model_config is not None: if model_config is not None:
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule) sampling_settings = model_config.sampling_settings
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) else:
sampling_settings = {}
beta_schedule = sampling_settings.get("beta_schedule", "linear")
linear_start = sampling_settings.get("linear_start", 0.00085)
linear_end = sampling_settings.get("linear_end", 0.012)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
self.sigma_data = 1.0 self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
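
ModelSamplingDiscrete now reads beta_schedule, linear_start and linear_end from the config's sampling_settings, so schedules like the x4 upscaler's 0.0001 to 0.02 linear range (see SD_X4Upscaler earlier in this diff) can be configured per model. A sketch of the lookup plus, hedged as the usual ldm convention rather than something shown in this hunk, the betas/sigmas it feeds:

    import torch

    sampling_settings = {"linear_start": 0.0001, "linear_end": 0.02}   # e.g. x4 upscaler
    beta_schedule = sampling_settings.get("beta_schedule", "linear")
    linear_start = sampling_settings.get("linear_start", 0.00085)
    linear_end = sampling_settings.get("linear_end", 0.012)

    # ldm's "linear" schedule squares a linspace of sqrt-betas (convention assumed here)
    betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, 1000, dtype=torch.float64) ** 2
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5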

View File

@ -6,7 +6,7 @@ import hashlib
import math import math
import random import random
from PIL import Image, ImageOps from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
@ -930,8 +930,8 @@ class GLIGENTextBoxApply:
return (c, ) return (c, )
class EmptyLatentImage: class EmptyLatentImage:
def __init__(self, device="cpu"): def __init__(self):
self.device = device self.device = comfy.model_management.intermediate_device()
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -944,7 +944,7 @@ class EmptyLatentImage:
CATEGORY = "latent" CATEGORY = "latent"
def generate(self, width, height, batch_size=1): def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8]) latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
return ({"samples":latent}, ) return ({"samples":latent}, )
@ -1395,17 +1395,30 @@ class LoadImage:
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image): def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image) image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path) img = Image.open(image_path)
i = ImageOps.exif_transpose(i) output_images = []
image = i.convert("RGB") output_masks = []
image = np.array(image).astype(np.float32) / 255.0 for i in ImageSequence.Iterator(img):
image = torch.from_numpy(image)[None,] i = ImageOps.exif_transpose(i)
if 'A' in i.getbands(): image = i.convert("RGB")
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask) image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") output_image = output_images[0]
return (image, mask.unsqueeze(0)) output_mask = output_masks[0]
return (output_image, output_mask)
@classmethod @classmethod
def IS_CHANGED(s, image): def IS_CHANGED(s, image):
@ -1463,13 +1476,10 @@ class LoadImageMask:
return m.digest().hex() return m.digest().hex()
@classmethod @classmethod
def VALIDATE_INPUTS(s, image, channel): def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image): if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image) return "Invalid image file: {}".format(image)
if channel not in s._color_channels:
return "Invalid color channel: {}".format(channel)
return True return True
class ImageScale: class ImageScale:
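
LoadImage now walks every frame of the file with PIL's ImageSequence, so animated GIFs and multi-page images come out as a batched tensor while single-frame files behave as before. A standalone sketch of the frame loop (the real node also builds a per-frame mask from the alpha channel):

    import numpy as np
    import torch
    from PIL import Image, ImageOps, ImageSequence

    def load_image_batch(path):
        img = Image.open(path)
        frames = []
        for frame in ImageSequence.Iterator(img):
            frame = ImageOps.exif_transpose(frame)
            arr = np.array(frame.convert("RGB")).astype(np.float32) / 255.0
            frames.append(torch.from_numpy(arr)[None,])     # (1, H, W, 3)
        return torch.cat(frames, dim=0) if len(frames) > 1 else frames[0]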

View File

@ -1,40 +1,115 @@
import torch import torch
from contextlib import contextmanager from contextlib import contextmanager
import comfy.model_management
class Linear(torch.nn.Linear): def cast_bias_weight(s, input):
def reset_parameters(self): bias = None
return None non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
return weight, bias
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
class Conv3d(torch.nn.Conv3d): class disable_weight_init:
def reset_parameters(self): class Linear(torch.nn.Linear):
return None comfy_cast_weights = False
def reset_parameters(self):
return None
def conv_nd(dims, *args, **kwargs): def forward_comfy_cast_weights(self, input):
if dims == 2: weight, bias = cast_bias_weight(self, input)
return Conv2d(*args, **kwargs) return torch.nn.functional.linear(input, weight, bias)
elif dims == 3:
return Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
@contextmanager def forward(self, *args, **kwargs):
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way if self.comfy_cast_weights:
old_torch_nn_linear = torch.nn.Linear return self.forward_comfy_cast_weights(*args, **kwargs)
force_device = device else:
force_dtype = dtype return super().forward(*args, **kwargs)
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
if force_device is not None:
device = force_device
if force_dtype is not None:
dtype = force_dtype
return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
torch.nn.Linear = linear_with_dtype class Conv2d(torch.nn.Conv2d):
try: comfy_cast_weights = False
yield def reset_parameters(self):
finally: return None
torch.nn.Linear = old_torch_nn_linear
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv3d(torch.nn.Conv3d):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class GroupNorm(torch.nn.GroupNorm):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
return s.Conv2d(*args, **kwargs)
elif dims == 3:
return s.Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
class manual_cast(disable_weight_init):
class Linear(disable_weight_init.Linear):
comfy_cast_weights = True
class Conv2d(disable_weight_init.Conv2d):
comfy_cast_weights = True
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True
class LayerNorm(disable_weight_init.LayerNorm):
comfy_cast_weights = True

View File

@ -0,0 +1,9 @@
from importlib.resources import path
import os
def get_editable_resource_path(caller_file, *package_path):
filename = os.path.join(os.path.dirname(os.path.realpath(caller_file)), package_path[-1])
if not os.path.exists(filename):
filename = path(*package_path)
return filename

View File

@ -47,7 +47,8 @@ def convert_cond(cond):
temp = c[1].copy() temp = c[1].copy()
model_conds = temp.get("model_conds", {}) model_conds = temp.get("model_conds", {})
if c[0] is not None: if c[0] is not None:
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds temp["model_conds"] = model_conds
out.append(temp) out.append(temp)
return out return out
@ -98,10 +99,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()
samples = samples.to(model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
@ -111,8 +112,8 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
sigmas = sigmas.to(model.load_device)
samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()
samples = samples.to(model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples
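Finished samples now go to a configurable intermediate device instead of always being moved to the CPU. A rough sketch of what such a helper could look like (this is an assumption for illustration, not the actual model_management code):
import torch

def intermediate_device_sketch(keep_outputs_on_gpu: bool = False) -> torch.device:
    # Illustrative: decide where finished samples should live between nodes.
    if keep_outputs_on_gpu and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")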

View File

@ -1,256 +1,264 @@
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
import torch
import collections
from . import model_management
import math
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in conds:
timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in conds:
timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in conds:
area = conds['area']
if 'strength' in conds:
strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = conds.get('control', None)
patches = None
if 'gligen' in conds:
gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
return cond_obj(input_x, mult, conditioning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys():
return False
for k in c1:
if not c1[k].can_concat(c2[k]):
return False
return True
def can_concat_cond(c1, c2):
if c1.input_x.shape != c2.input_x.shape:
return False
def objects_concatable(obj1, obj2):
if (obj1 is None) != (obj2 is None):
return False
if obj1 is not None:
if obj1 is not obj2:
return False
return True
if not objects_concatable(c1.control, c2.control):
return False
if not objects_concatable(c1.patches, c2.patches):
return False
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
temp = {}
for x in c_list:
for k in x:
cur = temp.get(k, [])
cur.append(x[k])
temp[k] = cur
out = {}
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
return out
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0
UNCOND = 1
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
if cond_or_uncond[o] == COND:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
del mult
out_cond /= out_count
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in conds:
timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in conds:
timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in conds:
area = conds['area']
if 'strength' in conds:
strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = None
if 'control' in conds:
control = conds['control']
patches = None
if 'gligen' in conds:
gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
return (input_x, mult, conditionning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys():
return False
for k in c1:
if not c1[k].can_concat(c2[k]):
return False
return True
def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape:
return False
#control
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
#patches
if (c1[5] is None) != (c2[5] is None):
return False
if (c1[5] is not None):
if c1[5] is not c2[5]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
temp = {}
for x in c_list:
for k in x:
cur = temp.get(k, [])
cur.append(x[k])
temp[k] = cur
out = {}
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
return out
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0
UNCOND = 1
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x += [p[0]]
mult += [p[1]]
c += [p[2]]
area += [p[3]]
cond_or_uncond += [o[1]]
control = p[4]
patches = p[5]
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
if cond_or_uncond[o] == COND:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
del mult
out_cond /= out_count
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
if math.isclose(cond_scale, 1.0):
uncond = None
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return x - model_options["sampler_cfg_function"](args)
else:
return uncond + (cond - uncond) * cond_scale
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
return cfg_result
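The loop above gives extensions a hook that runs after the CFG combine. A hedged example of registering such a hook (the "sampler_post_cfg_function" key comes from this diff; the blend itself is only illustrative):
import torch

def add_post_cfg_hook(model_options, strength=0.5):
    def blend_toward_cond(args):
        denoised = args["denoised"]
        cond_denoised = args["cond_denoised"]
        # Illustrative post-processing only; a real post-CFG function receives the
        # tensors in args and must return the new denoised result.
        return denoised * (1.0 - strength) + cond_denoised * strength
    model_options = dict(model_options)
    model_options["sampler_post_cfg_function"] = model_options.get("sampler_post_cfg_function", []) + [blend_toward_cond]
    return model_options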
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
@ -272,10 +280,7 @@ class KSamplerX0Inpaint(torch.nn.Module):
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
if denoise_mask is not None:
out *= denoise_mask
if denoise_mask is not None:
out += self.latent_image * latent_mask
out = out * denoise_mask + self.latent_image * latent_mask
return out
def simple_scheduler(model, steps):
@ -590,6 +595,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
@ -601,13 +613,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
@ -630,7 +635,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model, steps, sgm=True)
else:
print("error invalid scheduler", self.scheduler)
print("error invalid scheduler", scheduler_name)
return sigmas
def sampler_object(name):

View File

@ -148,12 +148,14 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
def __init__(self, sd=None, device=None, config=None):
def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8
self.latent_channels = 4
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@ -169,6 +171,11 @@ class VAE:
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
@ -185,8 +192,11 @@ class VAE:
device = model_management.vae_device()
self.device = device
offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype()
if dtype is None:
dtype = model_management.vae_dtype()
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@ -198,9 +208,9 @@ class VAE:
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
output = torch.clamp((
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
/ 3.0) / 2.0, min=0.0, max=1.0)
return output
@ -211,9 +221,9 @@ class VAE:
pbar = utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0
return samples
@ -225,15 +235,15 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@ -249,10 +259,10 @@ class VAE:
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@ -429,11 +439,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
class WeightsLoader(torch.nn.Module):
pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
model_config.set_manual_cast(manual_cast_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@ -466,7 +480,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
print("left over keys:", left_over)
if output_model:
_model_patcher = model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
@ -477,6 +491,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
@ -497,13 +514,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model_config.set_manual_cast(manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys in unet:", left_over)
return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path):
sd = utils.load_torch_file(unet_path)

View File

@ -1,6 +1,6 @@
import os
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
from transformers import CLIPTokenizer
from . import ops
import torch
import traceback
@ -8,6 +8,8 @@ import zipfile
from . import model_management
from pkg_resources import resource_filename
import contextlib
from . import clip_model
import json
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
@ -38,7 +40,7 @@ class ClipTokenWeightEncoder:
out, pooled = self.encode(to_encode)
if pooled is not None:
first_pooled = pooled[0:1].cpu()
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
@ -55,8 +57,8 @@ class ClipTokenWeightEncoder:
output.append(z)
if (len(output) == 0):
return out[-1:].cpu(), first_pooled
return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
@ -66,33 +68,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden" "hidden"
] ]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig, special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
self.num_layers = 12
if textmodel_path is not None:
self.transformer = model_class.from_pretrained(textmodel_path)
else:
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
config = config_class.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights():
self.transformer = model_class(config)
self.inner_name = inner_name if textmodel_json_config is None:
if dtype is not None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
self.transformer.to(dtype) if not os.path.exists(textmodel_json_config):
inner_model = getattr(self.transformer, self.inner_name) textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32) with open(textmodel_json_config) as f:
else: config = json.load(f)
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
self.transformer = model_class(config, dtype, device, ops.manual_cast)
self.num_layers = self.transformer.num_layers
self.max_length = max_length self.max_length = max_length
if freeze: if freeze:
@ -107,7 +97,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) <= self.num_layers
assert abs(layer_idx) < self.num_layers
self.clip_layer(layer_idx)
self.layer_default = (self.layer, self.layer_idx)
@ -118,7 +108,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
param.requires_grad = False
def clip_layer(self, layer_idx):
if abs(layer_idx) >= self.num_layers:
if abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
@ -173,41 +163,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32: attention_mask = None
precision_scope = torch.autocast if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
else: else:
precision_scope = lambda a, dtype: contextlib.nullcontext(a) z = outputs[1]
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32): if outputs[2] is not None:
attention_mask = None pooled_output = outputs[2].float()
if self.enable_attention_masks: else:
attention_mask = torch.zeros_like(tokens) pooled_output = None
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden") if self.text_projection is not None and pooled_output is not None:
self.transformer.set_input_embeddings(backup_embeds) pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output
def encode(self, tokens):

View File

@ -4,15 +4,15 @@ from . import sd1_clip
import os
class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=23
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):

View File

@ -3,13 +3,13 @@ import torch
import os
class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
def load_sd(self, sd):
@ -37,7 +37,7 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):

View File

@ -217,6 +217,16 @@ class SSD1B(SDXL):
"use_temporal_attention": False, "use_temporal_attention": False,
} }
class Segmind_Vega(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 1, 1, 2, 2],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
@ -242,5 +252,59 @@ class SVD_img2vid(supported_models_base.BASE):
def clip_target(self):
return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
class Stable_Zero123(supported_models_base.BASE):
unet_config = {
"context_dim": 768,
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
"in_channels": 8,
}
unet_extra_config = {
"num_heads": 8,
"num_head_channels": -1,
}
clip_vision_prefix = "cond_stage_model.model.visual."
latent_format = latent_formats.SD15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
return out
def clip_target(self):
return None
class SD_X4Upscaler(SD20):
unet_config = {
"context_dim": 1024,
"model_channels": 256,
'in_channels': 7,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
"disable_self_attentions": [True, True, True, False],
"num_classes": 1000,
"num_heads": 8,
"num_head_channels": -1,
}
latent_format = latent_formats.SD_X4
sampling_settings = {
"linear_start": 0.0001,
"linear_end": 0.02,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD_X4Upscaler(self, device=device)
return out
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
models += [SVD_img2vid]
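The order of the models list matters because the first config class whose pinned unet_config keys all match the detected UNet wins. A simplified sketch of that matching rule (illustrative, mirroring the matches classmethod in supported_models_base):
def config_matches(pinned: dict, detected: dict) -> bool:
    # Illustrative: every key a config class pins must be present with the same value.
    return all(detected.get(k) == v for k, v in pinned.items())

# e.g. Segmind_Vega pins "transformer_depth": [0, 0, 1, 1, 2, 2], which is what
# separates it from SDXL and SSD1B checkpoints during detection.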

View File

@ -22,6 +22,8 @@ class BASE:
sampling_settings = {}
latent_format = latent_formats.LatentFormat
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
@ -71,3 +73,5 @@ class BASE:
replace_prefix = {"": "first_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_manual_cast(self, manual_cast_dtype):
self.manual_cast_dtype = manual_cast_dtype
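The new attribute is filled in by the loader changes above; roughly, the wiring looks like this (paraphrased from load_checkpoint_guess_config in this diff, not additional API):
# unet_dtype = model_management.unet_dtype(model_params=parameters)
# manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
# model_config.set_manual_cast(manual_cast_dtype)  # None means no runtime casting is needed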

View File

@ -7,9 +7,10 @@ import torch
import torch.nn as nn
from .. import utils
from .. import ops
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
return ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
@ -19,7 +20,7 @@ class Block(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.skip = ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))

View File

@ -378,7 +378,7 @@ def lanczos(samples, width, height):
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
result = torch.stack(images)
return result
return result.to(samples.device, samples.dtype)
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
@ -407,17 +407,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
@torch.inference_mode()
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
for b in range(samples.shape[0]):
s = samples[b:b+1]
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
ps = function(s_in).cpu()
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):
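A hedged usage sketch of the new output_device parameter, assuming the helper is importable as comfy.utils.tiled_scale (the upscale function here is a stand-in, not a real model):
import torch
from comfy.utils import tiled_scale

def fake_upscaler(tile):
    # Stand-in for a model forward pass: 4x nearest-neighbour upscale.
    return torch.nn.functional.interpolate(tile, scale_factor=4, mode="nearest")

samples = torch.rand(1, 3, 64, 64)
out = tiled_scale(samples, fake_upscaler, tile_x=32, tile_y=32, overlap=8,
                  upscale_amount=4, out_channels=3, output_device="cpu")
# out has shape (1, 3, 256, 256) and is allocated on output_device.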

View File

@ -291,7 +291,7 @@ class Canny:
def detect_edge(self, image, low_threshold, high_threshold):
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
return (img_out,)
NODE_CLASS_MAPPINGS = {

View File

@ -1,9 +1,9 @@
import comfy.samplers
from comfy import samplers
import comfy.sample
from comfy import sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.cmd import latent_preview
import torch
import comfy.utils
from comfy import utils
class BasicScheduler:
@ -11,8 +11,9 @@ class BasicScheduler:
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
"scheduler": (samplers.SCHEDULER_NAMES, ),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
@ -20,8 +21,15 @@ class BasicScheduler:
FUNCTION = "get_sigmas"
def get_sigmas(self, model, scheduler, steps):
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu()
def get_sigmas(self, model, scheduler, steps, denoise):
total_steps = steps
if denoise < 1.0:
total_steps = int(steps/denoise)
inner_model = model.patch_model(patch_weights=False)
sigmas = samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
model.unpatch_model()
sigmas = sigmas[-(steps + 1):]
return (sigmas, )
@ -87,6 +95,7 @@ class SDTurboScheduler:
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
@ -94,9 +103,12 @@ class SDTurboScheduler:
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
sigmas = model.model.model_sampling.sigma(timesteps)
def get_sigmas(self, model, steps, denoise):
start_step = 10 - int(10 * denoise)
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
inner_model = model.patch_model(patch_weights=False)
sigmas = inner_model.model_sampling.sigma(timesteps)
model.unpatch_model()
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
@ -159,7 +171,7 @@ class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
{"sampler_name": (samplers.SAMPLER_NAMES, ),
}
}
RETURN_TYPES = ("SAMPLER",)
@ -168,7 +180,7 @@ class KSamplerSelect:
FUNCTION = "get_sampler"
def get_sampler(self, sampler_name):
sampler = comfy.samplers.sampler_object(sampler_name)
sampler = samplers.sampler_object(sampler_name)
return (sampler, )
class SamplerDPMPP_2M_SDE:
@ -191,7 +203,7 @@ class SamplerDPMPP_2M_SDE:
sampler_name = "dpmpp_2m_sde"
else:
sampler_name = "dpmpp_2m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
return (sampler, )
@ -215,7 +227,7 @@ class SamplerDPMPP_SDE:
sampler_name = "dpmpp_sde"
else:
sampler_name = "dpmpp_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, )
class SamplerCustom:
@ -248,7 +260,7 @@ class SamplerCustom:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = comfy.sample.prepare_noise(latent_image, noise_seed, batch_inds)
noise = sample.prepare_noise(latent_image, noise_seed, batch_inds)
noise_mask = None
if "noise_mask" in latent:
@ -257,8 +269,8 @@ class SamplerCustom:
x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
disable_pbar = not utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy()
out["samples"] = samples

View File

@ -2,9 +2,10 @@
import math import math
from einops import rearrange from einops import rearrange
import random # Use torch rng for consistency across generations
from torch import randint
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int: def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value) min_value = min(min_value, value)
# All big divisors of value (inclusive) # All big divisors of value (inclusive)
@ -12,8 +13,10 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter
ns = [value // i for i in divisors[:max_options]] # has at least 1 element ns = [value // i for i in divisors[:max_options]] # has at least 1 element
random.seed(counter) if len(ns) - 1 > 0:
idx = random.randint(0, len(ns) - 1) idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
else:
idx = 0
return ns[idx] return ns[idx]
@ -42,7 +45,6 @@ class HyperTile:
latent_tile_size = max(32, tile_size) // 8 latent_tile_size = max(32, tile_size) // 8
self.temp = None self.temp = None
self.counter = 1
def hypertile_in(q, k, v, extra_options): def hypertile_in(q, k, v, extra_options):
if q.shape[-1] in apply_to: if q.shape[-1] in apply_to:
@ -53,10 +55,8 @@ class HyperTile:
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter) nh = random_divisor(h, latent_tile_size * factor, swap_size)
self.counter += 1 nw = random_divisor(w, latent_tile_size * factor, swap_size)
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
self.counter += 1
if nh * nw > 1: if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
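A minimal standalone sketch of the new random_divisor, assuming the unchanged divisor-enumeration line (not shown in this hunk) matches the upstream helper; seeding torch now also fixes the tile split:

import torch
from torch import randint

def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    min_value = min(min_value, value)
    # All big divisors of value (inclusive) -- reconstructed context line so the sketch runs standalone
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]
    ns = [value // i for i in divisors[:max_options]]  # has at least 1 element
    if len(ns) - 1 > 0:
        idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
    else:
        idx = 0
    return ns[idx]

torch.manual_seed(0)  # the per-generation seed now determines the divisor choice as well
print(random_divisor(64, 8, max_options=3))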

View File

@ -73,7 +73,7 @@ class SaveAnimatedWEBP:
OUTPUT_NODE = True OUTPUT_NODE = True
CATEGORY = "_for_testing" CATEGORY = "image/animation"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method) method = self.methods.get(method)
@ -135,7 +135,7 @@ class SaveAnimatedPNG:
OUTPUT_NODE = True OUTPUT_NODE = True
CATEGORY = "_for_testing" CATEGORY = "image/animation"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append filename_prefix += self.prefix_append

View File

@ -3,9 +3,7 @@ import torch
def reshape_latent_to(target_shape, latent): def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]: if latent.shape[1:] != target_shape[1:]:
latent.movedim(1, -1)
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
latent.movedim(-1, 1)
return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
@ -102,9 +100,32 @@ class LatentInterpolate:
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,) return (samples_out,)
class LatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "batch"
CATEGORY = "latent/batch"
def batch(self, samples1, samples2):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
if s1.shape[1:] != s2.shape[1:]:
s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd, "LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract, "LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply, "LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate, "LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch,
} }
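A short sketch of what LatentBatch does with mismatched latents and batch indices (requires the comfy package on the path; the shapes are arbitrary):

import torch
import comfy.utils

samples1 = {"samples": torch.zeros(2, 4, 64, 64)}
samples2 = {"samples": torch.zeros(3, 4, 32, 32), "batch_index": [5, 6, 7]}

s1, s2 = samples1["samples"], samples2["samples"]
if s1.shape[1:] != s2.shape[1:]:
    # mismatched latents are resized to the first batch's spatial size
    s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
out = {"samples": torch.cat((s1, s2), dim=0)}
# batch_index falls back to range(n) for latents that do not carry one
out["batch_index"] = samples1.get("batch_index", list(range(s1.shape[0]))) + \
                     samples2.get("batch_index", list(range(s2.shape[0])))
print(out["samples"].shape, out["batch_index"])  # torch.Size([5, 4, 64, 64]) [0, 1, 5, 6, 7]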

View File

@ -7,6 +7,7 @@ from comfy.nodes.common import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
source = source.to(destination.device)
if resize_source: if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
@ -21,7 +22,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if mask is None: if mask is None:
mask = torch.ones_like(source) mask = torch.ones_like(source)
else: else:
mask = mask.clone() mask = mask.to(destination.device, copy=True)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])

View File

@ -16,41 +16,19 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteDistilled(torch.nn.Module): class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50 original_timesteps = 50
def __init__(self): def __init__(self, model_config=None):
super().__init__() super().__init__(model_config)
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 self.skip_steps = self.num_timesteps // self.original_timesteps
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
self.skip_steps = timesteps // self.original_timesteps sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps): for x in range(self.original_timesteps):
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 self.set_sigmas(sigmas_valid)
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma): def timestep(self, sigma):
log_sigma = sigma.log() log_sigma = sigma.log()
@ -65,14 +43,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device) return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
def rescale_zero_terminal_snr_sigmas(sigmas): def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1) alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
@ -121,7 +91,7 @@ class ModelSamplingDiscrete:
class ModelSamplingAdvanced(sampling_base, sampling_type): class ModelSamplingAdvanced(sampling_base, sampling_type):
pass pass
model_sampling = ModelSamplingAdvanced() model_sampling = ModelSamplingAdvanced(model.model.model_config)
if zsnr: if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
@ -153,7 +123,7 @@ class ModelSamplingContinuousEDM:
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass pass
model_sampling = ModelSamplingAdvanced() model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_sigma_range(sigma_min, sigma_max) model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling) m.add_object_patch("model_sampling", model_sampling)
return (m, ) return (m, )
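The distilled (LCM) schedule is now derived from the parent ModelSamplingDiscrete sigmas rather than rebuilt from a hand-rolled beta schedule. A standalone sketch of just the stride selection, with a dummy tensor standing in for self.sigmas:

import torch

num_timesteps = 1000
original_timesteps = 50
skip_steps = num_timesteps // original_timesteps  # 20

sigmas = torch.linspace(0.03, 15.0, num_timesteps)  # stand-in for the parent schedule
sigmas_valid = torch.zeros((original_timesteps,), dtype=torch.float32)
for x in range(original_timesteps):
    # walk backwards from the final timestep in strides of skip_steps
    sigmas_valid[original_timesteps - 1 - x] = sigmas[num_timesteps - 1 - x * skip_steps]

# keeps indices 19, 39, ..., 999 in ascending timestep order
print(sigmas_valid.shape)  # torch.Size([50])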

View File

@ -0,0 +1,53 @@
import torch
from comfy import sample
from comfy import samplers
class PerpNeg:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"empty_conditioning": ("CONDITIONING", ),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()
nocond = sample.convert_cond(empty_conditioning)
def cfg_function(args):
model = args["model"]
noise_pred_pos = args["cond_denoised"]
noise_pred_neg = args["uncond_denoised"]
cond_scale = args["cond_scale"]
x = args["input"]
sigma = args["sigma"]
model_options = args["model_options"]
nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
(noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
pos = noise_pred_pos - noise_pred_nocond
neg = noise_pred_neg - noise_pred_nocond
perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg
perp_neg = perp * neg_scale
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
cfg_result = x - cfg_result
return cfg_result
m.set_model_sampler_cfg_function(cfg_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"PerpNeg": PerpNeg,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PerpNeg": "Perp-Neg",
}

View File

@ -226,7 +226,7 @@ class Sharpen:
batch_size, height, width, channels = image.shape batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1 kernel_size = sharpen_radius * 2 + 1
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10) kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
center = kernel_size // 2 center = kernel_size // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

View File

@ -99,10 +99,40 @@ class LatentRebatch:
return (output_list,) return (output_list,)
class ImageRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "image/batch"
def rebatch(self, images, batch_size):
batch_size = batch_size[0]
output_list = []
all_images = []
for img in images:
for i in range(img.shape[0]):
all_images.append(img[i:i+1])
for i in range(0, len(all_images), batch_size):
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
return (output_list,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch, "RebatchLatents": LatentRebatch,
"RebatchImages": ImageRebatch,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents", "RebatchLatents": "Rebatch Latents",
} "RebatchImages": "Rebatch Images",
}

View File

@ -0,0 +1,168 @@
import torch
from torch import einsum
import torch.nn.functional as F
import math
from einops import rearrange, repeat
import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
from comfy import samplers
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return (out, sim)
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = math.ceil(math.sqrt(lh * lw / hw1))
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape
mask = (
mask.reshape(b, *mid_shape)
.unsqueeze(1)
.type(attn.dtype)
)
# Upsample
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class SelfAttentionGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, scale, blur_sigma):
m = model.clone()
attn_scores = None
# TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call
def attn_and_record(q, k, v, extra_options):
nonlocal attn_scores
# if uncond, save the attention scores
heads = extra_options["n_heads"]
cond_or_uncond = extra_options["cond_or_uncond"]
b = q.shape[0] // len(cond_or_uncond)
if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
return out
else:
return optimized_attention(q, k, v, heads=heads)
def post_cfg_function(args):
nonlocal attn_scores
uncond_attn = attn_scores
sag_scale = scale
sag_sigma = blur_sigma
sag_threshold = 1.0
model = args["model"]
uncond_pred = args["uncond_denoised"]
uncond = args["uncond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
model_options = args["model_options"]
x = args["input"]
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
# from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
return (m, )
NODE_CLASS_MAPPINGS = {
"SelfAttentionGuidance": SelfAttentionGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SelfAttentionGuidance": "Self-Attention Guidance",
}

View File

@ -0,0 +1,46 @@
import torch
from comfy import utils
class SD_4XUpscale_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/upscale_diffusion"
def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
width = max(1, round(images.shape[-2] * scale_ratio))
height = max(1, round(images.shape[-3] * scale_ratio))
pixels = utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
out_cp = []
out_cn = []
for t in positive:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cp.append(n)
for t in negative:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cn.append(n)
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
return (out_cp, out_cn, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
}

View File

@ -0,0 +1,61 @@
import torch
import comfy.utils
from comfy.nodes.common import MAX_RESOLUTION
from comfy import utils
def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation])
azimuth = torch.as_tensor([azimuth])
embeddings = torch.stack(
[
torch.deg2rad(
(90 - elevation) - (90)
), # Zero123 polar is 90-elevation
torch.sin(torch.deg2rad(azimuth)),
torch.cos(torch.deg2rad(azimuth)),
torch.deg2rad(
90 - torch.full_like(elevation, 0)
),
], dim=-1).unsqueeze(1)
return embeddings
class StableZero123_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = camera_embeddings(elevation, azimuth)
cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"StableZero123_Conditioning": StableZero123_Conditioning,
}
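camera_embeddings packs the relative pose into a (1, 1, 4) tensor: the polar offset from 90 degrees, sin/cos of the azimuth, and a constant 90-degree term. A worked call with arbitrary angles (the function body is copied from the diff so the example runs standalone):

import torch

def camera_embeddings(elevation, azimuth):
    elevation = torch.as_tensor([elevation])
    azimuth = torch.as_tensor([azimuth])
    return torch.stack([
        torch.deg2rad((90 - elevation) - 90),   # Zero123 polar is 90 - elevation
        torch.sin(torch.deg2rad(azimuth)),
        torch.cos(torch.deg2rad(azimuth)),
        torch.deg2rad(90 - torch.full_like(elevation, 0)),
    ], dim=-1).unsqueeze(1)

emb = camera_embeddings(30.0, 45.0)
print(emb.shape)  # torch.Size([1, 1, 4])
print(emb)        # approx [[-0.5236, 0.7071, 0.7071, 1.5708]]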

View File

@ -3,6 +3,7 @@ torchaudio
torchvision torchvision
torchdiffeq>=0.2.3 torchdiffeq>=0.2.3
torchsde>=0.2.6 torchsde>=0.2.6
torchvision
einops>=0.6.0 einops>=0.6.0
open-clip-torch>=2.16.0 open-clip-torch>=2.16.0
transformers>=4.29.1 transformers>=4.29.1

View File

@ -28,18 +28,18 @@ version = '0.0.1'
""" """
The package index to the torch built with AMD ROCm. The package index to the torch built with AMD ROCm.
""" """
amd_torch_index = "https://download.pytorch.org/whl/rocm5.6" amd_torch_index = ("https://download.pytorch.org/whl/rocm5.6", "https://download.pytorch.org/whl/nightly/rocm5.7")
""" """
The package index to torch built with CUDA. The package index to torch built with CUDA.
Observe the CUDA version is in this URL. Observe the CUDA version is in this URL.
""" """
nvidia_torch_index = "https://download.pytorch.org/whl/cu121" nvidia_torch_index = ("https://download.pytorch.org/whl/cu121", "https://download.pytorch.org/whl/nightly/cu121")
""" """
The package index to torch built against CPU features. The package index to torch built against CPU features.
""" """
cpu_torch_index = "https://download.pytorch.org/whl/cpu" cpu_torch_index = ("https://download.pytorch.org/whl/cpu", "https://download.pytorch.org/whl/nightly/cpu")
# xformers not required for new torch # xformers not required for new torch
@ -102,11 +102,11 @@ def _is_linux_arm64():
def dependencies() -> List[str]: def dependencies() -> List[str]:
_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines() _dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines()
# todo: also add all plugin dependencies
_alternative_indices = [amd_torch_index, nvidia_torch_index] _alternative_indices = [amd_torch_index, nvidia_torch_index]
session = PipSession() session = PipSession()
index_urls = ['https://pypi.org/simple'] # (stable, nightly) tuple
index_urls = [('https://pypi.org/simple', 'https://pypi.org/simple')]
# prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device # prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device
if _is_nvidia(): if _is_nvidia():
index_urls += [nvidia_torch_index] index_urls += [nvidia_torch_index]
@ -118,6 +118,13 @@ def dependencies() -> List[str]:
if len(index_urls) == 1: if len(index_urls) == 1:
return _dependencies return _dependencies
if sys.version_info >= (3, 12):
# use the nightlies
index_urls = [nightly for (_, nightly) in index_urls]
_alternative_indices = [nightly for (_, nightly) in _alternative_indices]
else:
index_urls = [stable for (stable, _) in index_urls]
_alternative_indices = [stable for (stable, _) in _alternative_indices]
try: try:
# pip 23 # pip 23
finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls, no_index=False)), finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls, no_index=False)),
@ -149,7 +156,7 @@ setup(
description="", description="",
author="", author="",
version=version, version=version,
python_requires=">=3.9,<3.12", python_requires=">=3.9,<3.13",
# todo: figure out how to include the web directory to eventually let main live inside the package # todo: figure out how to include the web directory to eventually let main live inside the package
# todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins # todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins
packages=find_packages(exclude=[] if is_editable else ['custom_nodes']), packages=find_packages(exclude=[] if is_editable else ['custom_nodes']),
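With each index stored as a (stable, nightly) pair, the Python 3.12 gate simply picks one element of every pair. A condensed sketch of that selection (URLs copied from the diff; that 3.12 needs the nightly indices is the commit's premise, not verified here):

import sys

nvidia_torch_index = ("https://download.pytorch.org/whl/cu121",
                      "https://download.pytorch.org/whl/nightly/cu121")
index_urls = [("https://pypi.org/simple", "https://pypi.org/simple"), nvidia_torch_index]

if sys.version_info >= (3, 12):
    # only the nightly wheels support 3.12, so switch every index to its nightly half
    index_urls = [nightly for (_, nightly) in index_urls]
else:
    index_urls = [stable for (stable, _) in index_urls]

print(index_urls)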

tests-ui/afterSetup.js Normal file
View File

@ -0,0 +1,9 @@
const { start } = require("./utils");
const lg = require("./utils/litegraph");
// Load things once per test file beforehand to ensure it's all warmed up for the tests
beforeAll(async () => {
lg.setup(global);
await start({ resetEnv: true });
lg.teardown(global);
});

View File

@ -2,8 +2,10 @@
const config = { const config = {
testEnvironment: "jsdom", testEnvironment: "jsdom",
setupFiles: ["./globalSetup.js"], setupFiles: ["./globalSetup.js"],
setupFilesAfterEnv: ["./afterSetup.js"],
clearMocks: true, clearMocks: true,
resetModules: true, resetModules: true,
testTimeout: 10000
}; };
module.exports = config; module.exports = config;

View File

@ -52,7 +52,7 @@ describe("extensions", () => {
const nodeNames = Object.keys(defs); const nodeNames = Object.keys(defs);
const nodeCount = nodeNames.length; const nodeCount = nodeNames.length;
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
for (let i = 0; i < nodeCount; i++) { for (let i = 0; i < 10; i++) {
// It should be send the JS class and the original JSON definition // It should be send the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]; const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]; const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
@ -133,7 +133,7 @@ describe("extensions", () => {
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2); expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1); expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2); expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
}); }, 15000);
it("allows custom nodeDefs and widgets to be registered", async () => { it("allows custom nodeDefs and widgets to be registered", async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => { const widgetMock = jest.fn((node, inputName, inputData, app) => {

View File

@ -1,7 +1,7 @@
// @ts-check // @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" /> /// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, createDefaultWorkflow } = require("../utils"); const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils");
const lg = require("../utils/litegraph"); const lg = require("../utils/litegraph");
describe("group node", () => { describe("group node", () => {
@ -273,7 +273,7 @@ describe("group node", () => {
let reroutes = []; let reroutes = [];
let prevNode = nodes.ckpt; let prevNode = nodes.ckpt;
for(let i = 0; i < 5; i++) { for (let i = 0; i < 5; i++) {
const reroute = ez.Reroute(); const reroute = ez.Reroute();
prevNode.outputs[0].connectTo(reroute.inputs[0]); prevNode.outputs[0].connectTo(reroute.inputs[0]);
prevNode = reroute; prevNode = reroute;
@ -283,7 +283,7 @@ describe("group node", () => {
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]); const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
expect((await graph.toPrompt()).output).toEqual(getOutput()); expect((await graph.toPrompt()).output).toEqual(getOutput());
group.menu["Convert to nodes"].call(); group.menu["Convert to nodes"].call();
expect((await graph.toPrompt()).output).toEqual(getOutput()); expect((await graph.toPrompt()).output).toEqual(getOutput());
}); });
@ -383,6 +383,43 @@ describe("group node", () => {
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id]) getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
); );
}); });
test("groups can connect to each other via internal reroutes", async () => {
const { ez, graph, app } = await start();
const latent = ez.EmptyLatentImage();
const vae = ez.VAELoader();
const latentReroute = ez.Reroute();
const vaeReroute = ez.Reroute();
latent.outputs[0].connectTo(latentReroute.inputs[0]);
vae.outputs[0].connectTo(vaeReroute.inputs[0]);
const group1 = await convertToGroup(app, graph, "test", [latentReroute, vaeReroute]);
group1.menu.Clone.call();
expect(app.graph._nodes).toHaveLength(4);
const group2 = graph.find(app.graph._nodes[3]);
expect(group2.node.type).toEqual("workflow/test");
expect(group2.id).not.toEqual(group1.id);
group1.outputs.VAE.connectTo(group2.inputs.VAE);
group1.outputs.LATENT.connectTo(group2.inputs.LATENT);
const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE);
const preview = ez.PreviewImage(decode.outputs[0]);
const output = {
[latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" },
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
[decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" },
[preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" },
};
expect((await graph.toPrompt()).output).toEqual(output);
// Ensure missing connections don't cause errors
group2.inputs.VAE.disconnect();
delete output[decode.id].inputs.vae;
expect((await graph.toPrompt()).output).toEqual(output);
});
test("displays generated image on group node", async () => { test("displays generated image on group node", async () => {
const { ez, graph, app } = await start(); const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph); const nodes = createDefaultWorkflow(ez, graph);
@ -642,6 +679,55 @@ describe("group node", () => {
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" }, 2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
}); });
}); });
test("correctly handles widget inputs", async () => {
const { ez, graph, app } = await start();
const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0];
const image = ez.LoadImage();
const scale1 = ez.ImageScaleBy(image.outputs[0]);
const scale2 = ez.ImageScaleBy(image.outputs[0]);
const preview1 = ez.PreviewImage(scale1.outputs[0]);
const preview2 = ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.upscale_method.value = upscaleMethods[1];
scale1.widgets.upscale_method.convertToInput();
const group = await convertToGroup(app, graph, "test", [scale1, scale2]);
expect(group.inputs.length).toBe(3);
expect(group.inputs[0].input.type).toBe("IMAGE");
expect(group.inputs[1].input.type).toBe("IMAGE");
expect(group.inputs[2].input.type).toBe("COMBO");
// Ensure links are maintained
expect(group.inputs[0].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[1].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[2].connection).toBeFalsy();
// Ensure primitive gets correct type
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs[2]);
expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods);
expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied
primitive.widgets.value.value = upscaleMethods[1];
await checkBeforeAndAfterReload(graph, async (r) => {
const scale1id = r ? `${group.id}:0` : scale1.id;
const scale2id = r ? `${group.id}:1` : scale2.id;
// Ensure widget value is applied to prompt
expect((await graph.toPrompt()).output).toStrictEqual({
[image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
[scale1id]: {
inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[scale2id]: {
inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" },
[preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" },
});
});
});
test("adds widgets in node execution order", async () => { test("adds widgets in node execution order", async () => {
const { ez, graph, app } = await start(); const { ez, graph, app } = await start();
const scale = ez.LatentUpscale(); const scale = ez.LatentUpscale();
@ -815,4 +901,105 @@ describe("group node", () => {
expect(p2.widgets.control_after_generate.value).toBe("randomize"); expect(p2.widgets.control_after_generate.value).toBe("randomize");
expect(p2.widgets.control_filter_list.value).toBe("/.+/"); expect(p2.widgets.control_filter_list.value).toBe("/.+/");
}); });
test("internal reroutes work with converted inputs and merge options", async () => {
const { ez, graph, app } = await start();
const vae = ez.VAELoader();
const latent = ez.EmptyLatentImage();
const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE);
const scale = ez.ImageScale(decode.outputs.IMAGE);
ez.PreviewImage(scale.outputs.IMAGE);
const r1 = ez.Reroute();
const r2 = ez.Reroute();
latent.widgets.width.value = 64;
latent.widgets.height.value = 128;
latent.widgets.width.convertToInput();
latent.widgets.height.convertToInput();
latent.widgets.batch_size.convertToInput();
scale.widgets.width.convertToInput();
scale.widgets.height.convertToInput();
r1.inputs[0].input.label = "hbw";
r1.outputs[0].connectTo(latent.inputs.height);
r1.outputs[0].connectTo(latent.inputs.batch_size);
r1.outputs[0].connectTo(scale.inputs.width);
r2.inputs[0].input.label = "wh";
r2.outputs[0].connectTo(latent.inputs.width);
r2.outputs[0].connectTo(scale.inputs.height);
const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]);
expect(group.inputs[0].input.type).toBe("VAE");
expect(group.inputs[1].input.type).toBe("INT");
expect(group.inputs[2].input.type).toBe("INT");
const p1 = ez.PrimitiveNode();
const p2 = ez.PrimitiveNode();
p1.outputs[0].connectTo(group.inputs[1]);
p2.outputs[0].connectTo(group.inputs[2]);
expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max
expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
expect(p1.widgets.value.value).toBe(128);
expect(p2.widgets.value.value).toBe(64);
p1.widgets.value.value = 16;
p2.widgets.value.value = 32;
await checkBeforeAndAfterReload(graph, async (r) => {
const id = (v) => (r ? `${group.id}:` : "") + v;
expect((await graph.toPrompt()).output).toStrictEqual({
1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
[id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" },
[id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" },
[id(4)]: {
inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] },
class_type: "ImageScale",
},
5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" },
});
});
});
test("converted inputs with linked widgets map values correctly on creation", async () => {
const { ez, graph, app } = await start();
const k1 = ez.KSampler();
const k2 = ez.KSampler();
k1.widgets.seed.convertToInput();
k2.widgets.seed.convertToInput();
const rr = ez.Reroute();
rr.outputs[0].connectTo(k1.inputs.seed);
rr.outputs[0].connectTo(k2.inputs.seed);
const group = await convertToGroup(app, graph, "test", [k1, k2, rr]);
expect(group.widgets.steps.value).toBe(20);
expect(group.widgets.cfg.value).toBe(8);
expect(group.widgets.scheduler.value).toBe("normal");
expect(group.widgets["KSampler steps"].value).toBe(20);
expect(group.widgets["KSampler cfg"].value).toBe(8);
expect(group.widgets["KSampler scheduler"].value).toBe("normal");
});
test("allow multiple of the same node type to be added", async () => {
const { ez, graph, app } = await start();
const nodes = [...Array(10)].map(() => ez.ImageScaleBy());
const group = await convertToGroup(app, graph, "test", nodes);
expect(group.inputs.length).toBe(10);
expect(group.outputs.length).toBe(10);
expect(group.widgets.length).toBe(20);
expect(group.widgets.map((w) => w.widget.name)).toStrictEqual(
[...Array(10)]
.map((_, i) => `${i > 0 ? "ImageScaleBy " : ""}${i > 1 ? i + " " : ""}`)
.flatMap((p) => [`${p}upscale_method`, `${p}scale_by`])
);
});
}); });

View File

@ -1,7 +1,13 @@
// @ts-check // @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" /> /// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils"); const {
start,
makeNodeDef,
checkBeforeAndAfterReload,
assertNotNullOrUndefined,
createDefaultWorkflow,
} = require("../utils");
const lg = require("../utils/litegraph"); const lg = require("../utils/litegraph");
/** /**
@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWi
if (controlWidgetCount) { if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate; const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo"); expect(controlWidget.widget.type).toBe("combo");
if(widgetType === "combo") { if (widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list; const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string"); expect(filterWidget.widget.type).toBe("string");
} }
@ -308,8 +314,8 @@ describe("widget inputs", () => {
const { ez } = await start({ const { ez } = await start({
mockNodeDefs: { mockNodeDefs: {
...makeNodeDef("TestNode1", {}, [["A", "B"]]), ...makeNodeDef("TestNode1", {}, [["A", "B"]]),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }), ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }), ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }),
}, },
}); });
@ -330,7 +336,7 @@ describe("widget inputs", () => {
const n1 = ez.TestNode1(); const n1 = ez.TestNode1();
n1.widgets.example.convertToInput(); n1.widgets.example.convertToInput();
const p = ez.PrimitiveNode() const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]); p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value; const value = p.widgets.value;
@ -380,7 +386,7 @@ describe("widget inputs", () => {
// Check random // Check random
control.value = "randomize"; control.value = "randomize";
filter.value = "/D/"; filter.value = "/D/";
for(let i = 0; i < 100; i++) { for (let i = 0; i < 100; i++) {
control["afterQueued"](); control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy(); expect(value.value === "D" || value.value === "DD").toBeTruthy();
} }
@ -392,4 +398,160 @@ describe("widget inputs", () => {
control["afterQueued"](); control["afterQueued"]();
expect(value.value).toBe("B"); expect(value.value).toBe("B");
}); });
describe("reroutes", () => {
async function checkOutput(graph, values) {
expect((await graph.toPrompt()).output).toStrictEqual({
1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" },
2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
4: {
inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 },
class_type: "EmptyLatentImage",
},
5: {
inputs: {
seed: 0,
steps: 20,
cfg: 8,
sampler_name: "euler",
scheduler: values?.scheduler ?? "normal",
denoise: 1,
model: ["1", 0],
positive: ["2", 0],
negative: ["3", 0],
latent_image: ["4", 0],
},
class_type: "KSampler",
},
6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" },
7: {
inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] },
class_type: "SaveImage",
},
});
}
async function waitForWidget(node) {
// widgets are created slightly after the graph is ready
// hard to find an exact hook to get these so just wait for them to be ready
for (let i = 0; i < 10; i++) {
await new Promise((r) => setTimeout(r, 10));
if (node.widgets?.value) {
return;
}
}
}
it("can connect primitive via a reroute path to a widget input", async () => {
const { ez, graph } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.empty.widgets.width.convertToInput();
nodes.sampler.widgets.scheduler.convertToInput();
nodes.save.widgets.filename_prefix.convertToInput();
let widthReroute = ez.Reroute();
let schedulerReroute = ez.Reroute();
let fileReroute = ez.Reroute();
let widthNext = widthReroute;
let schedulerNext = schedulerReroute;
let fileNext = fileReroute;
for (let i = 0; i < 5; i++) {
let next = ez.Reroute();
widthNext.outputs[0].connectTo(next.inputs[0]);
widthNext = next;
next = ez.Reroute();
schedulerNext.outputs[0].connectTo(next.inputs[0]);
schedulerNext = next;
next = ez.Reroute();
fileNext.outputs[0].connectTo(next.inputs[0]);
fileNext = next;
}
widthNext.outputs[0].connectTo(nodes.empty.inputs.width);
schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler);
fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix);
let widthPrimitive = ez.PrimitiveNode();
let schedulerPrimitive = ez.PrimitiveNode();
let filePrimitive = ez.PrimitiveNode();
widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]);
schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]);
filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]);
expect(widthPrimitive.widgets.value.value).toBe(512);
widthPrimitive.widgets.value.value = 1024;
expect(schedulerPrimitive.widgets.value.value).toBe("normal");
schedulerPrimitive.widgets.value.value = "simple";
expect(filePrimitive.widgets.value.value).toBe("ComfyUI");
filePrimitive.widgets.value.value = "ComfyTest";
await checkBeforeAndAfterReload(graph, async () => {
widthPrimitive = graph.find(widthPrimitive);
schedulerPrimitive = graph.find(schedulerPrimitive);
filePrimitive = graph.find(filePrimitive);
await waitForWidget(filePrimitive);
expect(widthPrimitive.widgets.length).toBe(2);
expect(schedulerPrimitive.widgets.length).toBe(3);
expect(filePrimitive.widgets.length).toBe(1);
await checkOutput(graph, {
width: 1024,
scheduler: "simple",
filename_prefix: "ComfyTest",
});
});
});
it("can connect primitive via a reroute path to multiple widget inputs", async () => {
const { ez, graph } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.empty.widgets.width.convertToInput();
nodes.empty.widgets.height.convertToInput();
nodes.empty.widgets.batch_size.convertToInput();
let reroute = ez.Reroute();
let prevReroute = reroute;
for (let i = 0; i < 5; i++) {
const next = ez.Reroute();
prevReroute.outputs[0].connectTo(next.inputs[0]);
prevReroute = next;
}
const r1 = ez.Reroute(prevReroute.outputs[0]);
const r2 = ez.Reroute(prevReroute.outputs[0]);
const r3 = ez.Reroute(r2.outputs[0]);
const r4 = ez.Reroute(r2.outputs[0]);
r1.outputs[0].connectTo(nodes.empty.inputs.width);
r3.outputs[0].connectTo(nodes.empty.inputs.height);
r4.outputs[0].connectTo(nodes.empty.inputs.batch_size);
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(reroute.inputs[0]);
expect(primitive.widgets.value.value).toBe(1);
primitive.widgets.value.value = 64;
await checkBeforeAndAfterReload(graph, async (r) => {
primitive = graph.find(primitive);
await waitForWidget(primitive);
// Ensure widget configs are merged
expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max
expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
await checkOutput(graph, {
width: 64,
height: 64,
batch_size: 64,
});
});
});
});
}); });

View File

@ -78,6 +78,14 @@ export class EzInput extends EzSlot {
this.input = input; this.input = input;
} }
get connection() {
const link = this.node.node.inputs?.[this.index]?.link;
if (link == null) {
return null;
}
return new EzConnection(this.node.app, this.node.app.graph.links[link]);
}
disconnect() { disconnect() {
this.node.node.disconnectInput(this.index); this.node.node.disconnectInput(this.index);
} }
@ -117,7 +125,7 @@ export class EzOutput extends EzSlot {
const inp = input.input; const inp = input.input;
const inName = inp.name || inp.label || inp.type; const inName = inp.name || inp.label || inp.type;
throw new Error( throw new Error(
`Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${ `Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${
this.output.name ?? this.output.type this.output.name ?? this.output.type
}#${this.index}] failed.` }#${this.index}] failed.`
); );
@ -179,6 +187,7 @@ export class EzWidget {
set value(v) { set value(v) {
this.widget.value = v; this.widget.value = v;
this.widget.callback?.call?.(this.widget, v)
} }
get isConvertedToInput() { get isConvertedToInput() {
@ -319,7 +328,7 @@ export class EzGraph {
} }
stringify() { stringify() {
return JSON.stringify(this.app.graph.serialize(), undefined, "\t"); return JSON.stringify(this.app.graph.serialize(), undefined);
} }
/** /**

View File

@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) {
return { ckpt, pos, neg, empty, sampler, decode, save }; return { ckpt, pos, neg, empty, sampler, decode, save };
} }
export async function getNodeDefs() {
const { api } = require("../../web/scripts/api");
return api.getNodeDefs();
}
export async function getNodeDef(nodeId) {
return (await getNodeDefs())[nodeId];
}

View File

@ -174,6 +174,11 @@ export class GroupNodeConfig {
node.index = i; node.index = i;
this.processNode(node, seenInputs, seenOutputs); this.processNode(node, seenInputs, seenOutputs);
} }
for (const p of this.#convertedToProcess) {
p();
}
this.#convertedToProcess = null;
await app.registerNodeDef("workflow/" + this.name, this.nodeDef); await app.registerNodeDef("workflow/" + this.name, this.nodeDef);
} }
@ -192,7 +197,10 @@ export class GroupNodeConfig {
if (!this.linksFrom[sourceNodeId]) { if (!this.linksFrom[sourceNodeId]) {
this.linksFrom[sourceNodeId] = {}; this.linksFrom[sourceNodeId] = {};
} }
this.linksFrom[sourceNodeId][sourceNodeSlot] = l; if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) {
this.linksFrom[sourceNodeId][sourceNodeSlot] = [];
}
this.linksFrom[sourceNodeId][sourceNodeSlot].push(l);
if (!this.linksTo[targetNodeId]) { if (!this.linksTo[targetNodeId]) {
this.linksTo[targetNodeId] = {}; this.linksTo[targetNodeId] = {};
@ -230,11 +238,11 @@ export class GroupNodeConfig {
// Skip as its not linked // Skip as its not linked
if (!linksFrom) return; if (!linksFrom) return;
let type = linksFrom["0"][5]; let type = linksFrom["0"][0][5];
if (type === "COMBO") { if (type === "COMBO") {
// Use the array items // Use the array items
const source = node.outputs[0].widget.name; const source = node.outputs[0].widget.name;
const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type; const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type;
const fromType = globalDefs[fromTypeName]; const fromType = globalDefs[fromTypeName];
const input = fromType.input.required[source] ?? fromType.input.optional[source]; const input = fromType.input.required[source] ?? fromType.input.optional[source];
type = input[0]; type = input[0];
@ -258,10 +266,33 @@ export class GroupNodeConfig {
return null; return null;
} }
let config = {};
let rerouteType = "*"; let rerouteType = "*";
if (linksFrom) { if (linksFrom) {
const [, , id, slot] = linksFrom["0"]; for (const [, , id, slot] of linksFrom["0"]) {
rerouteType = this.nodeData.nodes[id].inputs[slot].type; const node = this.nodeData.nodes[id];
const input = node.inputs[slot];
if (rerouteType === "*") {
rerouteType = input.type;
}
if (input.widget) {
const targetDef = globalDefs[node.type];
const targetWidget =
targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
const widget = [targetWidget[0], config];
const res = mergeIfValid(
{
widget,
},
targetWidget,
false,
null,
widget
);
config = res?.customConfig ?? config;
}
}
} else if (linksTo) { } else if (linksTo) {
const [id, slot] = linksTo["0"]; const [id, slot] = linksTo["0"];
rerouteType = this.nodeData.nodes[id].outputs[slot].type; rerouteType = this.nodeData.nodes[id].outputs[slot].type;
@ -282,10 +313,11 @@ export class GroupNodeConfig {
} }
} }
config.forceInput = true;
return { return {
input: { input: {
required: { required: {
[rerouteType]: [rerouteType, {}], [rerouteType]: [rerouteType, config],
}, },
}, },
output: [rerouteType], output: [rerouteType],
@ -299,16 +331,17 @@ export class GroupNodeConfig {
getInputConfig(node, inputName, seenInputs, config, extra) { getInputConfig(node, inputName, seenInputs, config, extra) {
let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
let key = name;
let prefix = ""; let prefix = "";
// Special handling for primitive to include the title if it is set rather than just "value" // Special handling for primitive to include the title if it is set rather than just "value"
if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) { if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) {
prefix = `${node.title ?? node.type} `; prefix = `${node.title ?? node.type} `;
name = `${prefix}${inputName}`; key = name = `${prefix}${inputName}`;
if (name in seenInputs) { if (name in seenInputs) {
name = `${prefix}${seenInputs[name]} ${inputName}`; name = `${prefix}${seenInputs[name]} ${inputName}`;
} }
} }
seenInputs[name] = (seenInputs[name] ?? 1) + 1; seenInputs[key] = (seenInputs[key] ?? 1) + 1;
if (inputName === "seed" || inputName === "noise_seed") { if (inputName === "seed" || inputName === "noise_seed") {
if (!extra) extra = {}; if (!extra) extra = {};
@ -420,10 +453,18 @@ export class GroupNodeConfig {
defaultInput: true, defaultInput: true,
}); });
this.nodeDef.input.required[name] = config; this.nodeDef.input.required[name] = config;
this.newToOldWidgetMap[name] = { node, inputName };
if (!this.oldToNewWidgetMap[node.index]) {
this.oldToNewWidgetMap[node.index] = {};
}
this.oldToNewWidgetMap[node.index][inputName] = name;
inputMap[slots.length + i] = this.inputCount++; inputMap[slots.length + i] = this.inputCount++;
} }
} }
#convertedToProcess = [];
processNodeInputs(node, seenInputs, inputs) { processNodeInputs(node, seenInputs, inputs) {
const inputMapping = []; const inputMapping = [];
@ -434,7 +475,11 @@ export class GroupNodeConfig {
const linksTo = this.linksTo[node.index] ?? {}; const linksTo = this.linksTo[node.index] ?? {};
const inputMap = (this.oldToNewInputMap[node.index] = {}); const inputMap = (this.oldToNewInputMap[node.index] = {});
this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs);
// Converted inputs have to be processed after all other nodes as they'll be at the end of the list
this.#convertedToProcess.push(() =>
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)
);
return inputMapping; return inputMapping;
} }
@ -597,11 +642,19 @@ export class GroupNodeHandler {
const output = this.groupData.newToOldOutputMap[link.origin_slot]; const output = this.groupData.newToOldOutputMap[link.origin_slot];
let innerNode = this.innerNodes[output.node.index]; let innerNode = this.innerNodes[output.node.index];
let l; let l;
while (innerNode.type === "Reroute") { while (innerNode?.type === "Reroute") {
l = innerNode.getInputLink(0); l = innerNode.getInputLink(0);
innerNode = innerNode.getInputNode(0); innerNode = innerNode.getInputNode(0);
} }
if (!innerNode) {
return null;
}
if (l && GroupNodeHandler.isGroupNode(innerNode)) {
return innerNode.updateLink(l);
}
link.origin_id = innerNode.id; link.origin_id = innerNode.id;
link.origin_slot = l?.origin_slot ?? output.slot; link.origin_slot = l?.origin_slot ?? output.slot;
return link; return link;
@ -665,6 +718,8 @@ export class GroupNodeHandler {
top = newNode.pos[1]; top = newNode.pos[1];
} }
if (!newNode.widgets) continue;
const map = this.groupData.oldToNewWidgetMap[innerNode.index]; const map = this.groupData.oldToNewWidgetMap[innerNode.index];
if (map) { if (map) {
const widgets = Object.keys(map); const widgets = Object.keys(map);
@ -721,7 +776,7 @@ export class GroupNodeHandler {
} }
}; };
const reconnectOutputs = () => { const reconnectOutputs = (selectedIds) => {
for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) { for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) {
const output = node.outputs[groupOutputId]; const output = node.outputs[groupOutputId];
if (!output.links) continue; if (!output.links) continue;
@ -861,7 +916,7 @@ export class GroupNodeHandler {
if (innerNode.type === "PrimitiveNode") { if (innerNode.type === "PrimitiveNode") {
innerNode.primitiveValue = newValue; innerNode.primitiveValue = newValue;
const primitiveLinked = this.groupData.primitiveToWidget[old.node.index]; const primitiveLinked = this.groupData.primitiveToWidget[old.node.index];
for (const linked of primitiveLinked) { for (const linked of primitiveLinked ?? []) {
const node = this.innerNodes[linked.nodeId]; const node = this.innerNodes[linked.nodeId];
const widget = node.widgets.find((w) => w.name === linked.inputName); const widget = node.widgets.find((w) => w.name === linked.inputName);
@ -870,6 +925,18 @@ export class GroupNodeHandler {
} }
} }
continue; continue;
} else if (innerNode.type === "Reroute") {
const rerouteLinks = this.groupData.linksFrom[old.node.index];
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
const node = this.innerNodes[targetNodeId];
const input = node.inputs[targetSlot];
if (input.widget) {
const widget = node.widgets?.find((w) => w.name === input.widget.name);
if (widget) {
widget.value = newValue;
}
}
}
} }
const widget = innerNode.widgets?.find((w) => w.name === old.inputName); const widget = innerNode.widgets?.find((w) => w.name === old.inputName);
@ -897,33 +964,58 @@ export class GroupNodeHandler {
this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value; this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value;
} }
} }
return true;
} }
populateReroute(node, nodeId, map) {
if (node.type !== "Reroute") return;
const link = this.groupData.linksFrom[nodeId]?.[0]?.[0];
if (!link) return;
const [, , targetNodeId, targetNodeSlot] = link;
const targetNode = this.groupData.nodeData.nodes[targetNodeId];
const inputs = targetNode.inputs;
const targetWidget = inputs?.[targetNodeSlot].widget;
if (!targetWidget) return;
const offset = inputs.length - (targetNode.widgets_values?.length ?? 0);
const v = targetNode.widgets_values?.[targetNodeSlot - offset];
if (v == null) return;
const widgetName = Object.values(map)[0];
const widget = this.node.widgets.find(w => w.name === widgetName);
if(widget) {
widget.value = v;
}
}
populateWidgets() { populateWidgets() {
if (!this.node.widgets) return;
for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) { for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) {
const node = this.groupData.nodeData.nodes[nodeId]; const node = this.groupData.nodeData.nodes[nodeId];
const map = this.groupData.oldToNewWidgetMap[nodeId] ?? {};
if (!node.widgets_values?.length) continue;
const map = this.groupData.oldToNewWidgetMap[nodeId];
const widgets = Object.keys(map); const widgets = Object.keys(map);
if (!node.widgets_values?.length) {
// special handling for populating values into reroutes
// this allows primitives connected to them to pick up the correct value
this.populateReroute(node, nodeId, map);
continue;
}
let linkedShift = 0; let linkedShift = 0;
for (let i = 0; i < widgets.length; i++) { for (let i = 0; i < widgets.length; i++) {
const oldName = widgets[i]; const oldName = widgets[i];
const newName = map[oldName]; const newName = map[oldName];
const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName);
const mainWidget = this.node.widgets[widgetIndex]; const mainWidget = this.node.widgets[widgetIndex];
if (!newName) { if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift) || widgetIndex === -1) {
// New name will be null if its a converted widget
this.populatePrimitive(node, nodeId, oldName, i, linkedShift);
// Find the inner widget and shift by the number of linked widgets as they will have been removed too // Find the inner widget and shift by the number of linked widgets as they will have been removed too
const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName); const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName);
linkedShift += innerWidget.linkedWidgets?.length ?? 0; linkedShift += innerWidget?.linkedWidgets?.length ?? 0;
continue;
} }
if (widgetIndex === -1) { if (widgetIndex === -1) {
continue; continue;
} }

View File

@ -33,6 +33,18 @@ function loadedImageToBlob(image) {
return blob; return blob;
} }
function loadImage(imagePath) {
return new Promise((resolve, reject) => {
const image = new Image();
image.onload = function() {
resolve(image);
};
image.src = imagePath;
});
}
async function uploadMask(filepath, formData) { async function uploadMask(filepath, formData) {
await api.fetchApi('/upload/mask', { await api.fetchApi('/upload/mask', {
method: 'POST', method: 'POST',
@ -50,25 +62,25 @@ async function uploadMask(filepath, formData) {
ClipspaceDialog.invalidatePreview(); ClipspaceDialog.invalidatePreview();
} }
function prepareRGB(image, backupCanvas, backupCtx) { function prepare_mask(image, maskCanvas, maskCtx) {
// paste mask data into alpha channel // paste mask data into alpha channel
backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height); maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height);
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height);
// refine mask image // invert mask
for (let i = 0; i < backupData.data.length; i += 4) { for (let i = 0; i < maskData.data.length; i += 4) {
if(backupData.data[i+3] == 255) if(maskData.data[i+3] == 255)
backupData.data[i+3] = 0; maskData.data[i+3] = 0;
else else
backupData.data[i+3] = 255; maskData.data[i+3] = 255;
backupData.data[i] = 0; maskData.data[i] = 0;
backupData.data[i+1] = 0; maskData.data[i+1] = 0;
backupData.data[i+2] = 0; maskData.data[i+2] = 0;
} }
backupCtx.globalCompositeOperation = 'source-over'; maskCtx.globalCompositeOperation = 'source-over';
backupCtx.putImageData(backupData, 0, 0); maskCtx.putImageData(maskData, 0, 0);
} }
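// A self-contained sketch of what prepare_mask does to the pixel buffer: fully opaque pixels
// become transparent, everything else becomes opaque black, so the mask ends up stored purely
// in the alpha channel. Operates on a plain RGBA Uint8ClampedArray; the helper name is illustrative.
function invertMaskAlpha(data) {
    for (let i = 0; i < data.length; i += 4) {
        data[i + 3] = data[i + 3] === 255 ? 0 : 255; // invert alpha
        data[i] = 0;     // R
        data[i + 1] = 0; // G
        data[i + 2] = 0; // B
    }
    return data;
}

// invertMaskAlpha(new Uint8ClampedArray([10, 20, 30, 255,  10, 20, 30, 0]))
//   -> [0, 0, 0, 0,  0, 0, 0, 255]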
class MaskEditorDialog extends ComfyDialog { class MaskEditorDialog extends ComfyDialog {
@ -155,10 +167,6 @@ class MaskEditorDialog extends ComfyDialog {
// If it is specified as relative, using it only as a hidden placeholder for padding is recommended // If it is specified as relative, using it only as a hidden placeholder for padding is recommended
// to prevent anomalies where it exceeds a certain size and goes outside of the window. // to prevent anomalies where it exceeds a certain size and goes outside of the window.
var placeholder = document.createElement("div");
placeholder.style.position = "relative";
placeholder.style.height = "50px";
var bottom_panel = document.createElement("div"); var bottom_panel = document.createElement("div");
bottom_panel.style.position = "absolute"; bottom_panel.style.position = "absolute";
bottom_panel.style.bottom = "0px"; bottom_panel.style.bottom = "0px";
@ -180,18 +188,16 @@ class MaskEditorDialog extends ComfyDialog {
this.brush = brush; this.brush = brush;
this.element.appendChild(imgCanvas); this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas); this.element.appendChild(maskCanvas);
this.element.appendChild(placeholder); // the canvases must have a lower z-index than bottom_panel to avoid covering its buttons
this.element.appendChild(bottom_panel); this.element.appendChild(bottom_panel);
document.body.appendChild(brush); document.body.appendChild(brush);
var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
self.brush_size = event.target.value; self.brush_size = event.target.value;
self.updateBrushPreview(self, null, null); self.updateBrushPreview(self, null, null);
}); });
var clearButton = this.createLeftButton("Clear", var clearButton = this.createLeftButton("Clear",
() => { () => {
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
}); });
var cancelButton = this.createRightButton("Cancel", () => { var cancelButton = this.createRightButton("Cancel", () => {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
@ -207,40 +213,42 @@ class MaskEditorDialog extends ComfyDialog {
this.element.appendChild(imgCanvas); this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas); this.element.appendChild(maskCanvas);
this.element.appendChild(placeholder); // the canvases must have a lower z-index than bottom_panel to avoid covering its buttons
this.element.appendChild(bottom_panel); this.element.appendChild(bottom_panel);
bottom_panel.appendChild(clearButton); bottom_panel.appendChild(clearButton);
bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(this.saveButton);
bottom_panel.appendChild(cancelButton); bottom_panel.appendChild(cancelButton);
bottom_panel.appendChild(brush_size_slider); bottom_panel.appendChild(this.brush_size_slider);
imgCanvas.style.position = "absolute";
maskCanvas.style.position = "absolute";
imgCanvas.style.position = "relative";
imgCanvas.style.top = "200"; imgCanvas.style.top = "200";
imgCanvas.style.left = "0"; imgCanvas.style.left = "0";
maskCanvas.style.position = "absolute"; maskCanvas.style.top = imgCanvas.style.top;
maskCanvas.style.left = imgCanvas.style.left;
} }
show() { async show() {
this.zoom_ratio = 1.0;
this.pan_x = 0;
this.pan_y = 0;
if(!this.is_layout_created) { if(!this.is_layout_created) {
// layout // layout
const imgCanvas = document.createElement('canvas'); const imgCanvas = document.createElement('canvas');
const maskCanvas = document.createElement('canvas'); const maskCanvas = document.createElement('canvas');
const backupCanvas = document.createElement('canvas');
imgCanvas.id = "imageCanvas"; imgCanvas.id = "imageCanvas";
maskCanvas.id = "maskCanvas"; maskCanvas.id = "maskCanvas";
backupCanvas.id = "backupCanvas";
this.setlayout(imgCanvas, maskCanvas); this.setlayout(imgCanvas, maskCanvas);
// prepare content // prepare content
this.imgCanvas = imgCanvas; this.imgCanvas = imgCanvas;
this.maskCanvas = maskCanvas; this.maskCanvas = maskCanvas;
this.backupCanvas = backupCanvas; this.maskCtx = maskCanvas.getContext('2d', {willReadFrequently: true });
this.maskCtx = maskCanvas.getContext('2d');
this.backupCtx = backupCanvas.getContext('2d');
this.setEventHandler(maskCanvas); this.setEventHandler(maskCanvas);
@ -252,6 +260,8 @@ class MaskEditorDialog extends ComfyDialog {
mutations.forEach(function(mutation) { mutations.forEach(function(mutation) {
if (mutation.type === 'attributes' && mutation.attributeName === 'style') { if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') { if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
self.brush.style.display = "none";
ComfyApp.onClipspaceEditorClosed(); ComfyApp.onClipspaceEditorClosed();
} }
@ -264,7 +274,8 @@ class MaskEditorDialog extends ComfyDialog {
observer.observe(this.element, config); observer.observe(this.element, config);
} }
this.setImages(this.imgCanvas, this.backupCanvas); // The keydown listener is removed when the dialog closes, so it needs to be registered again here.
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
if(ComfyApp.clipspace_return_node) { if(ComfyApp.clipspace_return_node) {
this.saveButton.innerText = "Save to node"; this.saveButton.innerText = "Save to node";
@ -275,97 +286,157 @@ class MaskEditorDialog extends ComfyDialog {
this.saveButton.disabled = false; this.saveButton.disabled = false;
this.element.style.display = "block"; this.element.style.display = "block";
this.element.style.width = "85%";
this.element.style.margin = "0 7.5%";
this.element.style.height = "100vh";
this.element.style.top = "50%";
this.element.style.left = "42%";
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority. this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
await this.setImages(this.imgCanvas);
this.is_visible = true;
} }
isOpened() { isOpened() {
return this.element.style.display == "block"; return this.element.style.display == "block";
} }
setImages(imgCanvas, backupCanvas) { invalidateCanvas(orig_image, mask_image) {
const imgCtx = imgCanvas.getContext('2d'); this.imgCanvas.width = orig_image.width;
const backupCtx = backupCanvas.getContext('2d'); this.imgCanvas.height = orig_image.height;
this.maskCanvas.width = orig_image.width;
this.maskCanvas.height = orig_image.height;
let imgCtx = this.imgCanvas.getContext('2d', {willReadFrequently: true });
let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true });
imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height);
prepare_mask(mask_image, this.maskCanvas, maskCtx);
}
async setImages(imgCanvas) {
let self = this;
const imgCtx = imgCanvas.getContext('2d', {willReadFrequently: true });
const maskCtx = this.maskCtx; const maskCtx = this.maskCtx;
const maskCanvas = this.maskCanvas; const maskCanvas = this.maskCanvas;
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height); imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height); maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
// image load // image load
const orig_image = new Image();
window.addEventListener("resize", () => {
// repositioning
imgCanvas.width = window.innerWidth - 250;
imgCanvas.height = window.innerHeight - 200;
// redraw image
let drawWidth = orig_image.width;
let drawHeight = orig_image.height;
if (orig_image.width > imgCanvas.width) {
drawWidth = imgCanvas.width;
drawHeight = (drawWidth / orig_image.width) * orig_image.height;
}
if (drawHeight > imgCanvas.height) {
drawHeight = imgCanvas.height;
drawWidth = (drawHeight / orig_image.height) * orig_image.width;
}
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
// update mask
maskCanvas.width = drawWidth;
maskCanvas.height = drawHeight;
maskCanvas.style.top = imgCanvas.offsetTop + "px";
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
});
const filepath = ComfyApp.clipspace.images; const filepath = ComfyApp.clipspace.images;
const touched_image = new Image();
touched_image.onload = function() {
backupCanvas.width = touched_image.width;
backupCanvas.height = touched_image.height;
prepareRGB(touched_image, backupCanvas, backupCtx);
};
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src) const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
alpha_url.searchParams.delete('channel'); alpha_url.searchParams.delete('channel');
alpha_url.searchParams.delete('preview'); alpha_url.searchParams.delete('preview');
alpha_url.searchParams.set('channel', 'a'); alpha_url.searchParams.set('channel', 'a');
touched_image.src = alpha_url; let mask_image = await loadImage(alpha_url);
// original image load // original image load
orig_image.onload = function() {
window.dispatchEvent(new Event('resize'));
};
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
rgb_url.searchParams.delete('channel'); rgb_url.searchParams.delete('channel');
rgb_url.searchParams.set('channel', 'rgb'); rgb_url.searchParams.set('channel', 'rgb');
orig_image.src = rgb_url; this.image = new Image();
this.image = orig_image; this.image.onload = function() {
maskCanvas.width = self.image.width;
maskCanvas.height = self.image.height;
self.invalidateCanvas(self.image, mask_image);
self.initializeCanvasPanZoom();
};
this.image.src = rgb_url;
} }
setEventHandler(maskCanvas) { initializeCanvasPanZoom() {
maskCanvas.addEventListener("contextmenu", (event) => { // set initialize
event.preventDefault(); let drawWidth = this.image.width;
}); let drawHeight = this.image.height;
let width = this.element.clientWidth;
let height = this.element.clientHeight;
if (this.image.width > width) {
drawWidth = width;
drawHeight = (drawWidth / this.image.width) * this.image.height;
}
if (drawHeight > height) {
drawHeight = height;
drawWidth = (drawHeight / this.image.height) * this.image.width;
}
this.zoom_ratio = drawWidth/this.image.width;
const canvasX = (width - drawWidth) / 2;
const canvasY = (height - drawHeight) / 2;
this.pan_x = canvasX;
this.pan_y = canvasY;
this.invalidatePanZoom();
}
invalidatePanZoom() {
let raw_width = this.image.width * this.zoom_ratio;
let raw_height = this.image.height * this.zoom_ratio;
if(this.pan_x + raw_width < 10) {
this.pan_x = 10 - raw_width;
}
if(this.pan_y + raw_height < 10) {
this.pan_y = 10 - raw_height;
}
let width = `${raw_width}px`;
let height = `${raw_height}px`;
let left = `${this.pan_x}px`;
let top = `${this.pan_y}px`;
this.maskCanvas.style.width = width;
this.maskCanvas.style.height = height;
this.maskCanvas.style.left = left;
this.maskCanvas.style.top = top;
this.imgCanvas.style.width = width;
this.imgCanvas.style.height = height;
this.imgCanvas.style.left = left;
this.imgCanvas.style.top = top;
}
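// A minimal standalone sketch of the fit-to-viewport math in initializeCanvasPanZoom above:
// shrink the image until it fits both dimensions, derive zoom_ratio from the scaled width,
// and centre the result; zoom_ratio and pan_x/pan_y then drive the CSS size/position applied
// in invalidatePanZoom. Pure function over plain numbers; the name is illustrative.
function fitImageToViewport(imgW, imgH, viewW, viewH) {
    let drawW = imgW, drawH = imgH;
    if (imgW > viewW) {
        drawW = viewW;
        drawH = (drawW / imgW) * imgH;
    }
    if (drawH > viewH) {
        drawH = viewH;
        drawW = (drawH / imgH) * imgW;
    }
    return {
        zoom_ratio: drawW / imgW,
        pan_x: (viewW - drawW) / 2,
        pan_y: (viewH - drawH) / 2,
    };
}

// fitImageToViewport(2000, 1000, 1000, 800) -> { zoom_ratio: 0.5, pan_x: 0, pan_y: 150 }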
setEventHandler(maskCanvas) {
const self = this; const self = this;
maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event)); if(!this.handler_registered) {
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp); maskCanvas.addEventListener("contextmenu", (event) => {
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event)); event.preventDefault();
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event)); });
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; }); this.element.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown); this.element.addEventListener('pointermove', (event) => this.pointMoveEvent(self,event));
this.element.addEventListener('touchmove', (event) => this.pointMoveEvent(self,event));
this.element.addEventListener('dragstart', (event) => {
if(event.ctrlKey) {
event.preventDefault();
}
});
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
this.handler_registered = true;
}
} }
brush_size = 10; brush_size = 10;
@ -378,8 +449,10 @@ class MaskEditorDialog extends ComfyDialog {
const self = MaskEditorDialog.instance; const self = MaskEditorDialog.instance;
if (event.key === ']') { if (event.key === ']') {
self.brush_size = Math.min(self.brush_size+2, 100); self.brush_size = Math.min(self.brush_size+2, 100);
self.brush_slider_input.value = self.brush_size;
} else if (event.key === '[') { } else if (event.key === '[') {
self.brush_size = Math.max(self.brush_size-2, 1); self.brush_size = Math.max(self.brush_size-2, 1);
self.brush_slider_input.value = self.brush_size;
} else if(event.key === 'Enter') { } else if(event.key === 'Enter') {
self.save(); self.save();
} }
@ -389,6 +462,10 @@ class MaskEditorDialog extends ComfyDialog {
static handlePointerUp(event) { static handlePointerUp(event) {
event.preventDefault(); event.preventDefault();
this.mousedown_x = null;
this.mousedown_y = null;
MaskEditorDialog.instance.drawing_mode = false; MaskEditorDialog.instance.drawing_mode = false;
} }
@ -398,24 +475,70 @@ class MaskEditorDialog extends ComfyDialog {
var centerX = self.cursorX; var centerX = self.cursorX;
var centerY = self.cursorY; var centerY = self.cursorY;
brush.style.width = self.brush_size * 2 + "px"; brush.style.width = self.brush_size * 2 * this.zoom_ratio + "px";
brush.style.height = self.brush_size * 2 + "px"; brush.style.height = self.brush_size * 2 * this.zoom_ratio + "px";
brush.style.left = (centerX - self.brush_size) + "px"; brush.style.left = (centerX - self.brush_size * this.zoom_ratio) + "px";
brush.style.top = (centerY - self.brush_size) + "px"; brush.style.top = (centerY - self.brush_size * this.zoom_ratio) + "px";
} }
handleWheelEvent(self, event) { handleWheelEvent(self, event) {
if(event.deltaY < 0) event.preventDefault();
self.brush_size = Math.min(self.brush_size+2, 100);
else
self.brush_size = Math.max(self.brush_size-2, 1);
self.brush_slider_input.value = self.brush_size; if(event.ctrlKey) {
// zoom canvas
if(event.deltaY < 0) {
this.zoom_ratio = Math.min(10.0, this.zoom_ratio+0.2);
}
else {
this.zoom_ratio = Math.max(0.2, this.zoom_ratio-0.2);
}
this.invalidatePanZoom();
}
else {
// adjust brush size
if(event.deltaY < 0)
this.brush_size = Math.min(this.brush_size+2, 100);
else
this.brush_size = Math.max(this.brush_size-2, 1);
this.brush_slider_input.value = this.brush_size;
this.updateBrushPreview(this);
}
}
pointMoveEvent(self, event) {
this.cursorX = event.pageX;
this.cursorY = event.pageY;
self.updateBrushPreview(self); self.updateBrushPreview(self);
if(event.ctrlKey) {
event.preventDefault();
self.pan_move(self, event);
}
}
pan_move(self, event) {
if(event.buttons == 1) {
if(this.mousedown_x) {
let deltaX = this.mousedown_x - event.clientX;
let deltaY = this.mousedown_y - event.clientY;
self.pan_x = this.mousedown_pan_x - deltaX;
self.pan_y = this.mousedown_pan_y - deltaY;
self.invalidatePanZoom();
}
}
} }
draw_move(self, event) { draw_move(self, event) {
if(event.ctrlKey) {
return;
}
event.preventDefault(); event.preventDefault();
this.cursorX = event.pageX; this.cursorX = event.pageX;
@ -439,6 +562,9 @@ class MaskEditorDialog extends ComfyDialog {
y = event.targetTouches[0].clientY - maskRect.top; y = event.targetTouches[0].clientY - maskRect.top;
} }
x /= self.zoom_ratio;
y /= self.zoom_ratio;
var brush_size = this.brush_size; var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') { if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure; brush_size *= event.pressure;
@ -489,8 +615,8 @@ class MaskEditorDialog extends ComfyDialog {
} }
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) { else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
const maskRect = self.maskCanvas.getBoundingClientRect(); const maskRect = self.maskCanvas.getBoundingClientRect();
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
var brush_size = this.brush_size; var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') { if(event instanceof PointerEvent && event.pointerType == 'pen') {
@ -540,6 +666,17 @@ class MaskEditorDialog extends ComfyDialog {
} }
handlePointerDown(self, event) { handlePointerDown(self, event) {
if(event.ctrlKey) {
if (event.buttons == 1) {
this.mousedown_x = event.clientX;
this.mousedown_y = event.clientY;
this.mousedown_pan_x = this.pan_x;
this.mousedown_pan_y = this.pan_y;
}
return;
}
var brush_size = this.brush_size; var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') { if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure; brush_size *= event.pressure;
@ -551,8 +688,8 @@ class MaskEditorDialog extends ComfyDialog {
event.preventDefault(); event.preventDefault();
const maskRect = self.maskCanvas.getBoundingClientRect(); const maskRect = self.maskCanvas.getBoundingClientRect();
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
self.maskCtx.beginPath(); self.maskCtx.beginPath();
if (event.button == 0) { if (event.button == 0) {
@ -570,15 +707,18 @@ class MaskEditorDialog extends ComfyDialog {
} }
async save() { async save() {
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); const backupCanvas = document.createElement('canvas');
const backupCtx = backupCanvas.getContext('2d', {willReadFrequently:true});
backupCanvas.width = this.image.width;
backupCanvas.height = this.image.height;
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); backupCtx.clearRect(0,0, backupCanvas.width, backupCanvas.height);
backupCtx.drawImage(this.maskCanvas, backupCtx.drawImage(this.maskCanvas,
0, 0, this.maskCanvas.width, this.maskCanvas.height, 0, 0, this.maskCanvas.width, this.maskCanvas.height,
0, 0, this.backupCanvas.width, this.backupCanvas.height); 0, 0, backupCanvas.width, backupCanvas.height);
// paste mask data into alpha channel // paste mask data into alpha channel
const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height); const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
// refine mask image // refine mask image
for (let i = 0; i < backupData.data.length; i += 4) { for (let i = 0; i < backupData.data.length; i += 4) {
@ -615,7 +755,7 @@ class MaskEditorDialog extends ComfyDialog {
ComfyApp.clipspace.widgets[index].value = item; ComfyApp.clipspace.widgets[index].value = item;
} }
const dataURL = this.backupCanvas.toDataURL(); const dataURL = backupCanvas.toDataURL();
const blob = dataURLToBlob(dataURL); const blob = dataURLToBlob(dataURL);
let original_url = new URL(this.image.src); let original_url = new URL(this.image.src);
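// The pointer handlers above divide offsetX/offsetY by zoom_ratio because the canvases are
// scaled with CSS while their backing stores stay at the original image size. A standalone
// sketch of that conversion; the event shape, names, and the ?? fallback are assumptions for
// illustration (the handlers above use || together with targetTouches for touch input).
function pointerToCanvasCoords(event, canvasRect, zoomRatio) {
    const x = (event.offsetX ?? event.clientX - canvasRect.left) / zoomRatio;
    const y = (event.offsetY ?? event.clientY - canvasRect.top) / zoomRatio;
    return { x, y };
}

// e.g. a click at offset (150, 90) on a canvas displayed at zoom_ratio 0.5 paints at (300, 180).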

View File

@ -1,10 +1,11 @@
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs.js";
// Node that allows you to redirect connections for cleaner graphs // Node that allows you to redirect connections for cleaner graphs
app.registerExtension({ app.registerExtension({
name: "Comfy.RerouteNode", name: "Comfy.RerouteNode",
registerCustomNodes() { registerCustomNodes(app) {
class RerouteNode { class RerouteNode {
constructor() { constructor() {
if (!this.properties) { if (!this.properties) {
@ -16,6 +17,12 @@ app.registerExtension({
this.addInput("", "*"); this.addInput("", "*");
this.addOutput(this.properties.showOutputText ? "*" : "", "*"); this.addOutput(this.properties.showOutputText ? "*" : "", "*");
this.onAfterGraphConfigured = function () {
requestAnimationFrame(() => {
this.onConnectionsChange(LiteGraph.INPUT, null, true, null);
});
};
this.onConnectionsChange = function (type, index, connected, link_info) { this.onConnectionsChange = function (type, index, connected, link_info) {
this.applyOrientation(); this.applyOrientation();
@ -47,6 +54,7 @@ app.registerExtension({
const linkId = currentNode.inputs[0].link; const linkId = currentNode.inputs[0].link;
if (linkId !== null) { if (linkId !== null) {
const link = app.graph.links[linkId]; const link = app.graph.links[linkId];
if (!link) return;
const node = app.graph.getNodeById(link.origin_id); const node = app.graph.getNodeById(link.origin_id);
const type = node.constructor.type; const type = node.constructor.type;
if (type === "Reroute") { if (type === "Reroute") {
@ -54,8 +62,7 @@ app.registerExtension({
// We've found a circle // We've found a circle
currentNode.disconnectInput(link.target_slot); currentNode.disconnectInput(link.target_slot);
currentNode = null; currentNode = null;
} } else {
else {
// Move the previous node // Move the previous node
currentNode = node; currentNode = node;
} }
@ -94,8 +101,11 @@ app.registerExtension({
updateNodes.push(node); updateNodes.push(node);
} else { } else {
// We've found an output // We've found an output
const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null; const nodeOutType =
if (inputType && nodeOutType !== inputType) { node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type
? node.inputs[link.target_slot].type
: null;
if (inputType && inputType !== "*" && nodeOutType !== inputType) {
// The output doesnt match our input so disconnect it // The output doesnt match our input so disconnect it
node.disconnectInput(link.target_slot); node.disconnectInput(link.target_slot);
} else { } else {
@ -111,6 +121,9 @@ app.registerExtension({
const displayType = inputType || outputType || "*"; const displayType = inputType || outputType || "*";
const color = LGraphCanvas.link_type_colors[displayType]; const color = LGraphCanvas.link_type_colors[displayType];
let widgetConfig;
let targetWidget;
let widgetType;
// Update the types of each node // Update the types of each node
for (const node of updateNodes) { for (const node of updateNodes) {
// If we dont have an input type we are always wildcard but we'll show the output type // If we dont have an input type we are always wildcard but we'll show the output type
@ -125,10 +138,38 @@ app.registerExtension({
const link = app.graph.links[l]; const link = app.graph.links[l];
if (link) { if (link) {
link.color = color; link.color = color;
if (app.configuringGraph) continue;
const targetNode = app.graph.getNodeById(link.target_id);
const targetInput = targetNode.inputs?.[link.target_slot];
if (targetInput?.widget) {
const config = getWidgetConfig(targetInput);
if (!widgetConfig) {
widgetConfig = config[1] ?? {};
widgetType = config[0];
}
if (!targetWidget) {
targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name);
}
const merged = mergeIfValid(targetInput, [config[0], widgetConfig]);
if (merged.customConfig) {
widgetConfig = merged.customConfig;
}
}
} }
} }
} }
for (const node of updateNodes) {
if (widgetConfig && outputType) {
node.inputs[0].widget = { name: "value" };
setWidgetConfig(node.inputs[0], [widgetType ?? displayType, widgetConfig], targetWidget);
} else {
setWidgetConfig(node.inputs[0], null);
}
}
if (inputNode) { if (inputNode) {
const link = app.graph.links[inputNode.inputs[0].link]; const link = app.graph.links[inputNode.inputs[0].link];
if (link) { if (link) {
@ -173,8 +214,8 @@ app.registerExtension({
}, },
{ {
// naming is inverted with respect to LiteGraphNode.horizontal // naming is inverted with respect to LiteGraphNode.horizontal
// LiteGraphNode.horizontal == true means that // LiteGraphNode.horizontal == true means that
// each slot in the inputs and outputs are layed out horizontally, // each slot in the inputs and outputs are layed out horizontally,
// which is the opposite of the visual orientation of the inputs and outputs as a node // which is the opposite of the visual orientation of the inputs and outputs as a node
content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"), content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"),
callback: () => { callback: () => {
@ -187,7 +228,7 @@ app.registerExtension({
applyOrientation() { applyOrientation() {
this.horizontal = this.properties.horizontal; this.horizontal = this.properties.horizontal;
if (this.horizontal) { if (this.horizontal) {
// we correct the input position, because LiteGraphNode.horizontal // we correct the input position, because LiteGraphNode.horizontal
// doesn't account for title presence // doesn't account for title presence
// which reroute nodes don't have // which reroute nodes don't have
this.inputs[0].pos = [this.size[0] / 2, 0]; this.inputs[0].pos = [this.size[0] / 2, 0];

View File

@ -1,5 +1,5 @@
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { applyTextReplacements } from "../../scripts/utils.js";
// Use widget values and dates in output filenames // Use widget values and dates in output filenames
app.registerExtension({ app.registerExtension({
@ -7,84 +7,19 @@ app.registerExtension({
async beforeRegisterNodeDef(nodeType, nodeData, app) { async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "SaveImage") { if (nodeData.name === "SaveImage") {
const onNodeCreated = nodeType.prototype.onNodeCreated; const onNodeCreated = nodeType.prototype.onNodeCreated;
// When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R
// Simple date formatter
const parts = {
d: (d) => d.getDate(),
M: (d) => d.getMonth() + 1,
h: (d) => d.getHours(),
m: (d) => d.getMinutes(),
s: (d) => d.getSeconds(),
};
const format =
Object.keys(parts)
.map((k) => k + k + "?")
.join("|") + "|yyy?y?";
function formatDate(text, date) {
return text.replace(new RegExp(format, "g"), function (text) {
if (text === "yy") return (date.getFullYear() + "").substring(2);
if (text === "yyyy") return date.getFullYear();
if (text[0] in parts) {
const p = parts[text[0]](date);
return (p + "").padStart(text.length, "0");
}
return text;
});
}
// When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R
nodeType.prototype.onNodeCreated = function () { nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
const widget = this.widgets.find((w) => w.name === "filename_prefix"); const widget = this.widgets.find((w) => w.name === "filename_prefix");
widget.serializeValue = () => { widget.serializeValue = () => {
return widget.value.replace(/%([^%]+)%/g, function (match, text) { return applyTextReplacements(app, widget.value);
const split = text.split(".");
if (split.length !== 2) {
// Special handling for dates
if (split[0].startsWith("date:")) {
return formatDate(split[0].substring(5), new Date());
}
if (text !== "width" && text !== "height") {
// Dont warn on standard replacements
console.warn("Invalid replacement pattern", text);
}
return match;
}
// Find node with matching S&R property name
let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]);
// If we cant, see if there is a node with that title
if (!nodes.length) {
nodes = app.graph._nodes.filter((n) => n.title === split[0]);
}
if (!nodes.length) {
console.warn("Unable to find node", split[0]);
return match;
}
if (nodes.length > 1) {
console.warn("Multiple nodes matched", split[0], "using first match");
}
const node = nodes[0];
const widget = node.widgets?.find((w) => w.name === split[1]);
if (!widget) {
console.warn("Unable to find widget", split[1], "on node", split[0], node);
return match;
}
return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
});
}; };
return r; return r;
}; };
} else { } else {
// When any other node is created add a property to alias the node // When any other node is created add a property to alias the node
const onNodeCreated = nodeType.prototype.onNodeCreated; const onNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function () { nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;

View File

@ -71,24 +71,21 @@ function graphEqual(a, b, root = true) {
} }
const undoRedo = async (e) => { const undoRedo = async (e) => {
const updateState = async (source, target) => {
const prevState = source.pop();
if (prevState) {
target.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState, false);
activeState = prevState;
}
}
if (e.ctrlKey || e.metaKey) { if (e.ctrlKey || e.metaKey) {
if (e.key === "y") { if (e.key === "y") {
const prevState = redo.pop(); updateState(redo, undo);
if (prevState) {
undo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true; return true;
} else if (e.key === "z") { } else if (e.key === "z") {
const prevState = undo.pop(); updateState(undo, redo);
if (prevState) {
redo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
return true; return true;
} }
} }
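// A standalone sketch of the updateState refactor above: undo and redo are two stacks of
// serialized graph states, and each operation pops one stack, pushes the currently active
// state onto the other, and makes the popped state active. Plain-data version, names illustrative.
function makeHistory(initialState) {
    const h = { undo: [], redo: [], active: initialState };
    const move = (source, target) => {
        const prev = source.pop();
        if (prev === undefined) return h.active;
        target.push(h.active);
        h.active = prev;
        return h.active;
    };
    return {
        push(state) { h.undo.push(h.active); h.active = state; h.redo.length = 0; },
        undo: () => move(h.undo, h.redo),
        redo: () => move(h.redo, h.undo),
        get active() { return h.active; },
    };
}

// const hist = makeHistory("A"); hist.push("B");
// hist.undo(); // active === "A"
// hist.redo(); // active === "B"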

View File

@ -1,10 +1,16 @@
import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js"; import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { applyTextReplacements } from "../../scripts/utils.js";
const CONVERTED_TYPE = "converted-widget"; const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
const CONFIG = Symbol(); const CONFIG = Symbol();
const GET_CONFIG = Symbol(); const GET_CONFIG = Symbol();
const TARGET = Symbol(); // Used for reroutes to specify the real target widget
export function getWidgetConfig(slot) {
return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]();
}
function getConfig(widgetName) { function getConfig(widgetName) {
const { nodeData } = this.constructor; const { nodeData } = this.constructor;
@ -100,7 +106,6 @@ function getWidgetType(config) {
return { type }; return { type };
} }
function isValidCombo(combo, obj) { function isValidCombo(combo, obj) {
// New input isnt a combo // New input isnt a combo
if (!(obj instanceof Array)) { if (!(obj instanceof Array)) {
@ -121,6 +126,31 @@ function isValidCombo(combo, obj) {
return true; return true;
} }
export function setWidgetConfig(slot, config, target) {
if (!slot.widget) return;
if (config) {
slot.widget[GET_CONFIG] = () => config;
slot.widget[TARGET] = target;
} else {
delete slot.widget;
}
if (slot.link) {
const link = app.graph.links[slot.link];
if (link) {
const originNode = app.graph.getNodeById(link.origin_id);
if (originNode.type === "PrimitiveNode") {
if (config) {
originNode.recreateWidget();
} else if(!app.configuringGraph) {
originNode.disconnectOutput(0);
originNode.onLastDisconnect();
}
}
}
}
}
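// A minimal sketch of the slot widget-config contract used by getWidgetConfig/setWidgetConfig
// above: the config lives behind a lazy getter (GET_CONFIG) on slot.widget, optionally shadowed
// by a cached CONFIG, and clearing the config removes the widget marker entirely. The symbols
// and helper names below are local stand-ins for illustration, not the module's own exports.
const GET_CFG = Symbol();
const CFG = Symbol();

function readConfig(slot) {
    return slot.widget[CFG] ?? slot.widget[GET_CFG]();
}

function writeConfig(slot, config) {
    if (!slot.widget) return;
    if (config) {
        slot.widget[GET_CFG] = () => config;
    } else {
        delete slot.widget;
    }
}

// const slot = { widget: { name: "value" } };
// writeConfig(slot, ["INT", { min: 0, max: 10 }]);
// readConfig(slot);        // -> ["INT", { min: 0, max: 10 }]
// writeConfig(slot, null); // slot.widget is removed, marking the input as a plain slot again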
export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) { export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
if (!config1) { if (!config1) {
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG](); config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
@ -150,7 +180,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT"; const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) { for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") { if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline") {
let v1 = config1[1][k]; let v1 = config1[1][k];
let v2 = config2[1]?.[k]; let v2 = config2[1]?.[k];
@ -405,11 +435,16 @@ app.registerExtension({
}; };
}, },
registerCustomNodes() { registerCustomNodes() {
const replacePropertyName = "Run widget replace on values";
class PrimitiveNode { class PrimitiveNode {
constructor() { constructor() {
this.addOutput("connect to widget input", "*"); this.addOutput("connect to widget input", "*");
this.serialize_widgets = true; this.serialize_widgets = true;
this.isVirtualNode = true; this.isVirtualNode = true;
if (!this.properties || !(replacePropertyName in this.properties)) {
this.addProperty(replacePropertyName, false, "boolean");
}
} }
applyToGraph(extraLinks = []) { applyToGraph(extraLinks = []) {
@ -430,18 +465,29 @@ app.registerExtension({
} }
let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks]; let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks];
let v = this.widgets?.[0].value;
if(v && this.properties[replacePropertyName]) {
v = applyTextReplacements(app, v);
}
// For each output link copy our value over the original widget value // For each output link copy our value over the original widget value
for (const linkInfo of links) { for (const linkInfo of links) {
const node = this.graph.getNodeById(linkInfo.target_id); const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot]; const input = node.inputs[linkInfo.target_slot];
const widgetName = input.widget.name; let widget;
if (widgetName) { if (input.widget[TARGET]) {
const widget = node.widgets.find((w) => w.name === widgetName); widget = input.widget[TARGET];
if (widget) { } else {
widget.value = this.widgets[0].value; const widgetName = input.widget.name;
if (widget.callback) { if (widgetName) {
widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); widget = node.widgets.find((w) => w.name === widgetName);
} }
}
if (widget) {
widget.value = v;
if (widget.callback) {
widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {});
} }
} }
} }
@ -494,14 +540,13 @@ app.registerExtension({
this.#mergeWidgetConfig(); this.#mergeWidgetConfig();
if (!links?.length) { if (!links?.length) {
this.#onLastDisconnect(); this.onLastDisconnect();
} }
} }
} }
onConnectOutput(slot, type, input, target_node, target_slot) { onConnectOutput(slot, type, input, target_node, target_slot) {
// Fires before the link is made allowing us to reject it if it isn't valid // Fires before the link is made allowing us to reject it if it isn't valid
// No widget, we cant connect // No widget, we cant connect
if (!input.widget) { if (!input.widget) {
if (!(input.type in ComfyWidgets)) return false; if (!(input.type in ComfyWidgets)) return false;
@ -519,6 +564,10 @@ app.registerExtension({
#onFirstConnection(recreating) { #onFirstConnection(recreating) {
// First connection can fire before the graph is ready on initial load so random things can be missing // First connection can fire before the graph is ready on initial load so random things can be missing
if (!this.outputs[0].links) {
this.onLastDisconnect();
return;
}
const linkId = this.outputs[0].links[0]; const linkId = this.outputs[0].links[0];
const link = this.graph.links[linkId]; const link = this.graph.links[linkId];
if (!link) return; if (!link) return;
@ -546,10 +595,10 @@ app.registerExtension({
this.outputs[0].name = type; this.outputs[0].name = type;
this.outputs[0].widget = widget; this.outputs[0].widget = widget;
this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating); this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating, widget[TARGET]);
} }
#createWidget(inputData, node, widgetName, recreating) { #createWidget(inputData, node, widgetName, recreating, targetWidget) {
let type = inputData[0]; let type = inputData[0];
if (type instanceof Array) { if (type instanceof Array) {
@ -563,7 +612,9 @@ app.registerExtension({
widget = this.addWidget(type, "value", null, () => {}, {}); widget = this.addWidget(type, "value", null, () => {}, {});
} }
if (node?.widgets && widget) { if (targetWidget) {
widget.value = targetWidget.value;
} else if (node?.widgets && widget) {
const theirWidget = node.widgets.find((w) => w.name === widgetName); const theirWidget = node.widgets.find((w) => w.name === widgetName);
if (theirWidget) { if (theirWidget) {
widget.value = theirWidget.value; widget.value = theirWidget.value;
@ -577,11 +628,19 @@ app.registerExtension({
} }
addValueControlWidgets(this, widget, control_value, undefined, inputData); addValueControlWidgets(this, widget, control_value, undefined, inputData);
let filter = this.widgets_values?.[2]; let filter = this.widgets_values?.[2];
if(filter && this.widgets.length === 3) { if (filter && this.widgets.length === 3) {
this.widgets[2].value = filter; this.widgets[2].value = filter;
} }
} }
// Restore any saved control values
const controlValues = this.controlValues;
if(this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) {
for(let i = 0; i < controlValues.length; i++) {
this.widgets[i + 1].value = controlValues[i];
}
}
// When our value changes, update other widgets to reflect our changes // When our value changes, update other widgets to reflect our changes
// e.g. so LoadImage shows correct image // e.g. so LoadImage shows correct image
const callback = widget.callback; const callback = widget.callback;
@ -610,12 +669,14 @@ app.registerExtension({
} }
} }
#recreateWidget() { recreateWidget() {
const values = this.widgets.map((w) => w.value); const values = this.widgets?.map((w) => w.value);
this.#removeWidgets(); this.#removeWidgets();
this.#onFirstConnection(true); this.#onFirstConnection(true);
for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; if (values?.length) {
return this.widgets[0]; for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
}
return this.widgets?.[0];
} }
#mergeWidgetConfig() { #mergeWidgetConfig() {
@ -631,7 +692,7 @@ app.registerExtension({
if (links?.length < 2 && hasConfig) { if (links?.length < 2 && hasConfig) {
// Copy the widget options from the source // Copy the widget options from the source
if (links.length) { if (links.length) {
this.#recreateWidget(); this.recreateWidget();
} }
return; return;
@ -657,7 +718,7 @@ app.registerExtension({
// Only allow connections where the configs match // Only allow connections where the configs match
const output = this.outputs[0]; const output = this.outputs[0];
const config2 = input.widget[GET_CONFIG](); const config2 = input.widget[GET_CONFIG]();
return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget); return !!mergeIfValid.call(this, output, config2, forceUpdate, this.recreateWidget);
} }
#removeWidgets() { #removeWidgets() {
@ -668,11 +729,20 @@ app.registerExtension({
w.onRemove(); w.onRemove();
} }
} }
// Temporarily store the current values in case the node is being recreated
// e.g. by group node conversion
this.controlValues = [];
this.lastType = this.widgets[0]?.type;
for(let i = 1; i < this.widgets.length; i++) {
this.controlValues.push(this.widgets[i].value);
}
setTimeout(() => { delete this.lastType; delete this.controlValues }, 15);
this.widgets.length = 0; this.widgets.length = 0;
} }
} }
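// The controlValues/lastType fields set above act as a short-lived stash: control widget values
// are parked on the node just before its widgets are torn down, so a recreation in the same tick
// (e.g. group node conversion) can read them back before the timeout clears them. A generic
// sketch of that pattern; the helper name and timeout value are illustrative.
function stashBriefly(target, key, value, ttlMs = 15) {
    target[key] = value;
    setTimeout(() => { delete target[key]; }, ttlMs);
}

// stashBriefly(node, "controlValues", node.widgets.slice(1).map((w) => w.value));
// // ...synchronously rebuild the widgets and read node.controlValues before the timeout fires.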
#onLastDisconnect() { onLastDisconnect() {
// We cant remove + re-add the output here as if you drag a link over the same link // We cant remove + re-add the output here as if you drag a link over the same link
// it removes, then re-adds, causing it to break // it removes, then re-adds, causing it to break
this.outputs[0].type = "*"; this.outputs[0].type = "*";

View File

@ -48,7 +48,7 @@
EVENT_LINK_COLOR: "#A86", EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA", CONNECTING_LINK_COLOR: "#AFA",
MAX_NUMBER_OF_NODES: 1000, //avoid infinite loops MAX_NUMBER_OF_NODES: 10000, //avoid infinite loops
DEFAULT_POSITION: [100, 100], //default node position DEFAULT_POSITION: [100, 100], //default node position
VALID_SHAPES: ["default", "box", "round", "card"], //,"circle" VALID_SHAPES: ["default", "box", "round", "card"], //,"circle"
@ -3788,16 +3788,42 @@
/** /**
* returns the bounding of the object, used for rendering purposes * returns the bounding of the object, used for rendering purposes
* bounding is: [topleft_cornerx, topleft_cornery, width, height]
* @method getBounding * @method getBounding
* @return {Float32Array[4]} the total size * @param out {Float32Array[4]?} [optional] a place to store the output, to free garbage
* @param compute_outer {boolean?} [optional] set to true to include the shadow and connection points in the bounding calculation
* @return {Float32Array[4]} the bounding box in format of [topleft_cornerx, topleft_cornery, width, height]
*/ */
LGraphNode.prototype.getBounding = function(out) { LGraphNode.prototype.getBounding = function(out, compute_outer) {
out = out || new Float32Array(4); out = out || new Float32Array(4);
out[0] = this.pos[0] - 4; const nodePos = this.pos;
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT; const isCollapsed = this.flags.collapsed;
out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4; const nodeSize = this.size;
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
let left_offset = 0;
// 1 offset due to how nodes are rendered
let right_offset = 1 ;
let top_offset = 0;
let bottom_offset = 0;
if (compute_outer) {
// 4 offset for collapsed node connection points
left_offset = 4;
// 6 offset for right shadow and collapsed node connection points
right_offset = 6 + left_offset;
// 4 offset for collapsed nodes top connection points
top_offset = 4;
// 5 offset for bottom shadow and collapsed node connection points
bottom_offset = 5 + top_offset;
}
out[0] = nodePos[0] - left_offset;
out[1] = nodePos[1] - LiteGraph.NODE_TITLE_HEIGHT - top_offset;
out[2] = isCollapsed ?
(this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) + right_offset :
nodeSize[0] + right_offset;
out[3] = isCollapsed ?
LiteGraph.NODE_TITLE_HEIGHT + bottom_offset :
nodeSize[1] + LiteGraph.NODE_TITLE_HEIGHT + bottom_offset;
if (this.onBounding) { if (this.onBounding) {
this.onBounding(out); this.onBounding(out);
@ -7674,7 +7700,7 @@ LGraphNode.prototype.executeAction = function(action)
continue; continue;
} }
if (!overlapBounding(this.visible_area, n.getBounding(temp))) { if (!overlapBounding(this.visible_area, n.getBounding(temp, true))) {
continue; continue;
} //out of the visible area } //out of the visible area
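// A standalone sketch of the culling check above: with compute_outer the bounding box is padded
// for shadows and collapsed-node connection dots, so nodes whose shadow still touches the
// viewport are not skipped. Boxes are [x, y, width, height]; this mirrors a standard AABB
// intersection test and is assumed to behave like litegraph's own overlapBounding helper.
function overlapBoundingSketch(a, b) {
    return a[0] < b[0] + b[2] && a[0] + a[2] > b[0] &&
           a[1] < b[1] + b[3] && a[1] + a[3] > b[1];
}

// overlapBoundingSketch([0, 0, 100, 100], [95, 95, 20, 20])  === true   (corner still visible)
// overlapBoundingSketch([0, 0, 100, 100], [120, 0, 20, 20])  === false  (fully outside to the right)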
@ -11336,6 +11362,7 @@ LGraphNode.prototype.executeAction = function(action)
name_element.innerText = title; name_element.innerText = title;
var value_element = dialog.querySelector(".value"); var value_element = dialog.querySelector(".value");
value_element.value = value; value_element.value = value;
value_element.select();
var input = value_element; var input = value_element;
input.addEventListener("keydown", function(e) { input.addEventListener("keydown", function(e) {

View File

@ -1559,9 +1559,12 @@ export class ComfyApp {
/** /**
* Populates the graph with the specified workflow data * Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object * @param {*} graphData A serialized graph object
* @param { boolean } clean If the graph state, e.g. images, should be cleared
*/ */
async loadGraphData(graphData) { async loadGraphData(graphData, clean = true) {
this.clean(); if (clean !== false) {
this.clean();
}
let reset_invalid_values = false; let reset_invalid_values = false;
if (!graphData) { if (!graphData) {
@ -1771,15 +1774,26 @@ export class ComfyApp {
if (parent?.updateLink) { if (parent?.updateLink) {
link = parent.updateLink(link); link = parent.updateLink(link);
} }
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)]; if (link) {
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
}
} }
} }
} }
output[String(node.id)] = { let node_data = {
inputs, inputs,
class_type: node.comfyClass, class_type: node.comfyClass,
}; };
if (this.ui.settings.getSettingValue("Comfy.DevMode")) {
// Ignored by the backend.
node_data["_meta"] = {
title: node.title,
}
}
output[String(node.id)] = node_data;
} }
} }
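// With Comfy.DevMode enabled, each entry in the serialized prompt gains a _meta block that the
// backend ignores; a sketch of the resulting shape for one node (ids and values illustrative):
const examplePromptEntry = {
    "3": {
        inputs: {
            seed: 1234,
            model: ["4", 0], // [origin node id, origin output slot]
        },
        class_type: "KSampler",
        _meta: {
            title: "KSampler",
        },
    },
};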
@ -2006,12 +2020,8 @@ export class ComfyApp {
async refreshComboInNodes() { async refreshComboInNodes() {
const defs = await api.getNodeDefs(); const defs = await api.getNodeDefs();
for(const nodeId in LiteGraph.registered_node_types) { for (const nodeId in defs) {
const node = LiteGraph.registered_node_types[nodeId]; this.registerNodeDef(nodeId, defs[nodeId]);
const nodeDef = defs[nodeId];
if(!nodeDef) continue;
node.nodeData = nodeDef;
} }
for(let nodeNum in this.graph._nodes) { for(let nodeNum in this.graph._nodes) {

View File

@ -177,6 +177,7 @@ LGraphCanvas.prototype.computeVisibleNodes = function () {
for (const w of node.widgets) { for (const w of node.widgets) {
if (w.element) { if (w.element) {
w.element.hidden = hidden; w.element.hidden = hidden;
w.element.style.display = hidden ? "none" : undefined;
if (hidden) { if (hidden) {
w.options.onHide?.(w); w.options.onHide?.(w);
} }

67
web/scripts/utils.js Normal file
View File

@ -0,0 +1,67 @@
// Simple date formatter
const parts = {
d: (d) => d.getDate(),
M: (d) => d.getMonth() + 1,
h: (d) => d.getHours(),
m: (d) => d.getMinutes(),
s: (d) => d.getSeconds(),
};
const format =
Object.keys(parts)
.map((k) => k + k + "?")
.join("|") + "|yyy?y?";
function formatDate(text, date) {
return text.replace(new RegExp(format, "g"), function (text) {
if (text === "yy") return (date.getFullYear() + "").substring(2);
if (text === "yyyy") return date.getFullYear();
if (text[0] in parts) {
const p = parts[text[0]](date);
return (p + "").padStart(text.length, "0");
}
return text;
});
}
export function applyTextReplacements(app, value) {
return value.replace(/%([^%]+)%/g, function (match, text) {
const split = text.split(".");
if (split.length !== 2) {
// Special handling for dates
if (split[0].startsWith("date:")) {
return formatDate(split[0].substring(5), new Date());
}
if (text !== "width" && text !== "height") {
// Dont warn on standard replacements
console.warn("Invalid replacement pattern", text);
}
return match;
}
// Find node with matching S&R property name
let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]);
// If we cant, see if there is a node with that title
if (!nodes.length) {
nodes = app.graph._nodes.filter((n) => n.title === split[0]);
}
if (!nodes.length) {
console.warn("Unable to find node", split[0]);
return match;
}
if (nodes.length > 1) {
console.warn("Multiple nodes matched", split[0], "using first match");
}
const node = nodes[0];
const widget = node.widgets?.find((w) => w.name === split[1]);
if (!widget) {
console.warn("Unable to find widget", split[1], "on node", split[0], node);
return match;
}
return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
});
}
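// Example of the replacement syntax handled above (results depend on the loaded graph, so this
// is illustrative only): %date:...% patterns go through formatDate, while %NodeName.widgetName%
// looks the node up by its "Node name for S&R" property or title and substitutes the widget's
// value with path separators replaced by underscores.
//
//   applyTextReplacements(app, "ComfyUI_%date:yyyy-MM-dd%_%KSampler.seed%")
//     -> "ComfyUI_2024-01-03_1234"   // e.g. on 2024-01-03, if a node named "KSampler" has seed 1234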