mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Merge upstream, fix 3.12 compatibility, fix nightlies issue, fix broken node
This commit is contained in:
commit
369aeb598f
@ -1,3 +0,0 @@
|
||||
..\python_embeded\python.exe .\update.py ..\ComfyUI\
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
|
||||
pause
|
||||
@ -1,2 +0,0 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention
|
||||
pause
|
||||
2
.github/workflows/test-ui.yaml
vendored
2
.github/workflows/test-ui.yaml
vendored
@ -22,5 +22,5 @@ jobs:
|
||||
run: |
|
||||
npm ci
|
||||
npm run test:generate
|
||||
npm test
|
||||
npm test -- --verbose
|
||||
working-directory: ./tests-ui
|
||||
|
||||
@ -2,6 +2,24 @@ name: "Windows Release Nightly pytorch"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
cu:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "121"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
required: true
|
||||
type: string
|
||||
default: "12"
|
||||
|
||||
python_patch:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "1"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@ -20,21 +38,21 @@ jobs:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11.6'
|
||||
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
|
||||
- shell: bash
|
||||
run: |
|
||||
cd ..
|
||||
cp -r ComfyUI ComfyUI_copy
|
||||
curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip
|
||||
curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
|
||||
unzip python_embeded.zip -d python_embeded
|
||||
cd python_embeded
|
||||
echo 'import site' >> ./python311._pth
|
||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||
ls ../temp_wheel_dir
|
||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||
sed -i '1i../ComfyUI' ./python311._pth
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
cd ..
|
||||
|
||||
git clone https://github.com/comfyanonymous/taesd
|
||||
@ -49,9 +67,10 @@ jobs:
|
||||
mkdir update
|
||||
cp -r ComfyUI/.ci/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/windows_base_files/* ./
|
||||
cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
|
||||
cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
|
||||
|
||||
echo "..\python_embeded\python.exe .\update.py ..\ComfyUI\\
|
||||
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
|
||||
pause" > ./update/update_comfyui_and_python_dependencies.bat
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
|
||||
|
||||
9
.vscode/settings.json
vendored
9
.vscode/settings.json
vendored
@ -1,9 +0,0 @@
|
||||
{
|
||||
"path-intellisense.mappings": {
|
||||
"../": "${workspaceFolder}/web/extensions/core"
|
||||
},
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "ms-python.autopep8"
|
||||
},
|
||||
"python.formatting.provider": "none"
|
||||
}
|
||||
0
__init__.py
Normal file
0
__init__.py
Normal file
@ -53,7 +53,7 @@ class ControlNet(nn.Module):
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
device=None,
|
||||
operations=ops,
|
||||
operations=ops.disable_weight_init,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@ -141,24 +141,24 @@ class ControlNet(nn.Module):
|
||||
)
|
||||
]
|
||||
)
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
|
||||
|
||||
self.input_hint_block = TimestepEmbedSequential(
|
||||
operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
|
||||
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 16, 16, 3, padding=1),
|
||||
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
||||
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 32, 32, 3, padding=1),
|
||||
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
||||
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 96, 96, 3, padding=1),
|
||||
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
||||
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
|
||||
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
||||
)
|
||||
|
||||
self._feature_size = model_channels
|
||||
@ -206,7 +206,7 @@ class ControlNet(nn.Module):
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
|
||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
@ -234,7 +234,7 @@ class ControlNet(nn.Module):
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
|
||||
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
@ -276,14 +276,14 @@ class ControlNet(nn.Module):
|
||||
operations=operations
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
|
||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
|
||||
self._feature_size += ch
|
||||
|
||||
def make_zero_conv(self, channels, operations=None):
|
||||
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
|
||||
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
@ -295,7 +295,7 @@ class ControlNet(nn.Module):
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h = x
|
||||
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
||||
if guided_hint is not None:
|
||||
h = module(h, emb, context)
|
||||
|
||||
@ -55,13 +55,19 @@ fp_group = parser.add_mutually_exclusive_group()
|
||||
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
||||
|
||||
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||
fpunet_group = parser.add_mutually_exclusive_group()
|
||||
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
|
||||
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
|
||||
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
|
||||
|
||||
fpvae_group = parser.add_mutually_exclusive_group()
|
||||
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
||||
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
|
||||
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
|
||||
|
||||
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
|
||||
|
||||
fpte_group = parser.add_mutually_exclusive_group()
|
||||
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
|
||||
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
|
||||
@ -98,7 +104,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
|
||||
|
||||
|
||||
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
|
||||
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||
|
||||
188
comfy/clip_model.py
Normal file
188
comfy/clip_model.py
Normal file
@ -0,0 +1,188 @@
|
||||
import torch
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
class CLIPAttention(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
self.heads = heads
|
||||
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, mask=None, optimized_attention=None):
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
|
||||
out = optimized_attention(q, k, v, self.heads, mask)
|
||||
return self.out_proj(out)
|
||||
|
||||
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
|
||||
"gelu": torch.nn.functional.gelu,
|
||||
}
|
||||
|
||||
class CLIPMLP(torch.nn.Module):
|
||||
def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.activation = ACTIVATIONS[activation]
|
||||
self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.activation(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
class CLIPLayer(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
|
||||
def forward(self, x, mask=None, optimized_attention=None):
|
||||
x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
|
||||
x += self.mlp(self.layer_norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class CLIPEncoder(torch.nn.Module):
|
||||
def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
|
||||
|
||||
def forward(self, x, mask=None, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
|
||||
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layers) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layers):
|
||||
x = l(x, mask, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
class CLIPEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
|
||||
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens):
|
||||
return self.token_embedding(input_tokens) + self.position_embedding.weight
|
||||
|
||||
|
||||
class CLIPTextModel_(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
|
||||
x = self.embeddings(input_tokens)
|
||||
mask = None
|
||||
if attention_mask is not None:
|
||||
mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
|
||||
|
||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
|
||||
if mask is not None:
|
||||
mask += causal_mask
|
||||
else:
|
||||
mask = causal_mask
|
||||
|
||||
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
|
||||
x = self.final_layer_norm(x)
|
||||
if i is not None and final_layer_norm_intermediate:
|
||||
i = self.final_layer_norm(i)
|
||||
|
||||
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPTextModel(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.num_layers = config_dict["num_hidden_layers"]
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.text_model.embeddings.token_embedding
|
||||
|
||||
def set_input_embeddings(self, embeddings):
|
||||
self.text_model.embeddings.token_embedding = embeddings
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.text_model(*args, **kwargs)
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
|
||||
|
||||
self.patch_embedding = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
num_patches = (image_size // patch_size) ** 2
|
||||
num_positions = num_patches + 1
|
||||
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
|
||||
return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
|
||||
|
||||
|
||||
class CLIPVision(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
|
||||
self.pre_layrnorm = operations.LayerNorm(embed_dim)
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.post_layernorm = operations.LayerNorm(embed_dim)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
x = self.embeddings(pixel_values)
|
||||
x = self.pre_layrnorm(x)
|
||||
#TODO: attention_mask?
|
||||
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
|
||||
pooled_output = self.post_layernorm(x[:, 0, :])
|
||||
return x, i, pooled_output
|
||||
|
||||
class CLIPVisionModelProjection(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
|
||||
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
x = self.vision_model(*args, **kwargs)
|
||||
out = self.visual_projection(x[2])
|
||||
return (x[0], x[1], out)
|
||||
@ -1,36 +1,43 @@
|
||||
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
|
||||
from .utils import load_torch_file, transformers_convert
|
||||
import os
|
||||
import torch
|
||||
import contextlib
|
||||
import json
|
||||
|
||||
from . import ops
|
||||
from . import model_patcher
|
||||
from . import model_management
|
||||
from . import clip_model
|
||||
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
def __setitem__(self, key, item):
|
||||
setattr(self, key, item)
|
||||
|
||||
def clip_preprocess(image, size=224):
|
||||
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
|
||||
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
|
||||
scale = (size / min(image.shape[1], image.shape[2]))
|
||||
image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = image.movedim(-1, 1)
|
||||
if not (image.shape[2] == size and image.shape[3] == size):
|
||||
scale = (size / min(image.shape[2], image.shape[3]))
|
||||
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
|
||||
h = (image.shape[2] - size)//2
|
||||
w = (image.shape[3] - size)//2
|
||||
image = image[:,:,h:h+size,w:w+size]
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
config = CLIPVisionConfig.from_json_file(json_config)
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
self.dtype = torch.float32
|
||||
if model_management.should_use_fp16(self.load_device, prioritize_performance=False):
|
||||
self.dtype = torch.float16
|
||||
|
||||
with ops.use_comfy_ops(offload_device, self.dtype):
|
||||
with modeling_utils.no_init_weights():
|
||||
self.model = CLIPVisionModelWithProjection(config)
|
||||
self.model.to(self.dtype)
|
||||
self.dtype = model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
def load_sd(self, sd):
|
||||
@ -38,25 +45,13 @@ class ClipVisionModel():
|
||||
|
||||
def encode_image(self, image):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = clip_preprocess(image.to(self.load_device))
|
||||
|
||||
if self.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
||||
|
||||
with precision_scope(model_management.get_autocast_device(self.load_device), torch.float32):
|
||||
outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
|
||||
|
||||
for k in outputs:
|
||||
t = outputs[k]
|
||||
if t is not None:
|
||||
if k == 'hidden_states':
|
||||
outputs["penultimate_hidden_states"] = t[-2].cpu()
|
||||
outputs["hidden_states"] = None
|
||||
else:
|
||||
outputs[k] = t.cpu()
|
||||
pixel_values = clip_preprocess(image.to(self.load_device)).float()
|
||||
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
|
||||
|
||||
outputs = Output()
|
||||
outputs["last_hidden_state"] = out[0].to(model_management.intermediate_device())
|
||||
outputs["image_embeds"] = out[2].to(model_management.intermediate_device())
|
||||
outputs["penultimate_hidden_states"] = out[1].to(model_management.intermediate_device())
|
||||
return outputs
|
||||
|
||||
def convert_to_transformers(sd, prefix):
|
||||
@ -86,6 +81,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
if convert_keys:
|
||||
sd = convert_to_transformers(sd, prefix)
|
||||
if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
|
||||
# todo: fix the importlib issue here
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
|
||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||
|
||||
@ -13,6 +13,8 @@ import typing
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
import sys
|
||||
import gc
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
|
||||
@ -442,6 +444,8 @@ class PromptExecutor:
|
||||
for x in executed:
|
||||
self.old_prompt[x] = copy.deepcopy(prompt[x])
|
||||
self.server.last_node_id = None
|
||||
if model_management.DISABLE_SMART_MEMORY:
|
||||
model_management.unload_all_models()
|
||||
|
||||
|
||||
|
||||
@ -462,6 +466,14 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
||||
errors = []
|
||||
valid = True
|
||||
|
||||
# todo: investigate if these are at the right indent level
|
||||
info = None
|
||||
val = None
|
||||
|
||||
validate_function_inputs = []
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
|
||||
|
||||
for x in required_inputs:
|
||||
if x not in inputs:
|
||||
error = {
|
||||
@ -591,29 +603,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if hasattr(obj_class, "VALIDATE_INPUTS"):
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
# ret = obj_class.VALIDATE_INPUTS(**input_data_all)
|
||||
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
|
||||
for i, r3 in enumerate(ret):
|
||||
if r3 is not True:
|
||||
details = f"{x}"
|
||||
if r3 is not False:
|
||||
details += f" - {str(r3)}"
|
||||
|
||||
error = {
|
||||
"type": "custom_validation_failed",
|
||||
"message": "Custom validation failed for node",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
else:
|
||||
if x not in validate_function_inputs:
|
||||
if isinstance(type_input, list):
|
||||
if val not in type_input:
|
||||
input_config = info
|
||||
@ -640,6 +630,35 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(validate_function_inputs) > 0:
|
||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
||||
input_filtered = {}
|
||||
for x in input_data_all:
|
||||
if x in validate_function_inputs:
|
||||
input_filtered[x] = input_data_all[x]
|
||||
|
||||
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
|
||||
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
|
||||
for x in input_filtered:
|
||||
for i, r in enumerate(ret):
|
||||
if r is not True:
|
||||
details = f"{x}"
|
||||
if r is not False:
|
||||
details += f" - {str(r)}"
|
||||
|
||||
error = {
|
||||
"type": "custom_validation_failed",
|
||||
"message": "Custom validation failed for node",
|
||||
"details": details,
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": info,
|
||||
"received_value": val,
|
||||
}
|
||||
}
|
||||
errors.append(error)
|
||||
continue
|
||||
|
||||
if len(errors) > 0 or valid is not True:
|
||||
ret = (False, errors, unique_id)
|
||||
else:
|
||||
@ -771,7 +790,7 @@ class PromptQueue:
|
||||
self.server.queue_updated()
|
||||
self.not_empty.notify()
|
||||
|
||||
def get(self, timeout=None) -> typing.Tuple[QueueTuple, int]:
|
||||
def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
|
||||
with self.not_empty:
|
||||
while len(self.queue) == 0:
|
||||
self.not_empty.wait(timeout=timeout)
|
||||
|
||||
@ -188,8 +188,7 @@ def cached_filename_list_(folder_name):
|
||||
if folder_name not in filename_list_cache:
|
||||
return None
|
||||
out = filename_list_cache[folder_name]
|
||||
if time.perf_counter() < (out[2] + 0.5):
|
||||
return out
|
||||
|
||||
for x in out[1]:
|
||||
time_modified = out[1][x]
|
||||
folder = x
|
||||
|
||||
@ -23,9 +23,9 @@ def execute_prestartup_script():
|
||||
return False
|
||||
|
||||
node_paths = folder_paths.get_folder_paths("custom_nodes")
|
||||
node_prestartup_times = []
|
||||
for custom_node_path in node_paths:
|
||||
possible_modules = os.listdir(custom_node_path) if os.path.exists(custom_node_path) else []
|
||||
node_prestartup_times = []
|
||||
|
||||
for possible_module in possible_modules:
|
||||
module_path = os.path.join(custom_node_path, possible_module)
|
||||
@ -69,6 +69,10 @@ if args.cuda_device is not None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
|
||||
print("Set cuda device to:", args.cuda_device)
|
||||
|
||||
if args.deterministic:
|
||||
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
|
||||
|
||||
from .. import utils
|
||||
import yaml
|
||||
|
||||
@ -78,12 +82,12 @@ from .server import BinaryEventTypes
|
||||
from .. import model_management
|
||||
|
||||
|
||||
def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer):
|
||||
def prompt_worker(q, _server):
|
||||
e = execution.PromptExecutor(_server)
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
|
||||
current_time = 0.0
|
||||
while True:
|
||||
timeout = None
|
||||
if need_gc:
|
||||
@ -94,11 +98,13 @@ def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer)
|
||||
item, item_id = queue_item
|
||||
execution_start_time = time.perf_counter()
|
||||
prompt_id = item[1]
|
||||
_server.last_prompt_id = prompt_id
|
||||
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
need_gc = True
|
||||
q.task_done(item_id, e.outputs_ui)
|
||||
if _server.client_id is not None:
|
||||
_server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id)
|
||||
_server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id)
|
||||
|
||||
current_time = time.perf_counter()
|
||||
execution_time = current_time - execution_start_time
|
||||
@ -119,7 +125,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
|
||||
|
||||
def hijack_progress(server):
|
||||
def hook(value, total, preview_image):
|
||||
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
|
||||
|
||||
server.send_sync("progress", progress, server.client_id)
|
||||
if preview_image is not None:
|
||||
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
|
||||
|
||||
@ -204,7 +213,7 @@ def main():
|
||||
print(f"Setting output directory to: {output_dir}")
|
||||
folder_paths.set_output_directory(output_dir)
|
||||
|
||||
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
|
||||
# These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
|
||||
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
|
||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||
|
||||
@ -734,7 +734,8 @@ class PromptServer():
|
||||
message = self.encode_bytes(event, data)
|
||||
|
||||
if sid is None:
|
||||
for ws in self.sockets.values():
|
||||
sockets = list(self.sockets.values())
|
||||
for ws in sockets:
|
||||
await send_socket_catch_exception(ws.send_bytes, message)
|
||||
elif sid in self.sockets:
|
||||
await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
|
||||
@ -743,7 +744,8 @@ class PromptServer():
|
||||
message = {"type": event, "data": data}
|
||||
|
||||
if sid is None:
|
||||
for ws in self.sockets.values():
|
||||
sockets = list(self.sockets.values())
|
||||
for ws in sockets:
|
||||
await send_socket_catch_exception(ws.send_json, message)
|
||||
elif sid in self.sockets:
|
||||
await send_socket_catch_exception(self.sockets[sid].send_json, message)
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
import torch
|
||||
import math
|
||||
import os
|
||||
import contextlib
|
||||
|
||||
from . import utils
|
||||
from . import model_management
|
||||
from . import model_detection
|
||||
from . import model_patcher
|
||||
from . import ops
|
||||
|
||||
from .cldm import cldm
|
||||
from .t2i_adapter import adapter
|
||||
@ -34,13 +37,13 @@ class ControlBase:
|
||||
self.cond_hint = None
|
||||
self.strength = 1.0
|
||||
self.timestep_percent_range = (0.0, 1.0)
|
||||
self.global_average_pooling = False
|
||||
self.timestep_range = None
|
||||
|
||||
if device is None:
|
||||
device = model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.global_average_pooling = False
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
|
||||
self.cond_hint_original = cond_hint
|
||||
@ -75,6 +78,7 @@ class ControlBase:
|
||||
c.cond_hint_original = self.cond_hint_original
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
c.global_average_pooling = self.global_average_pooling
|
||||
|
||||
def inference_memory_requirements(self, dtype):
|
||||
if self.previous_controlnet is not None:
|
||||
@ -127,12 +131,14 @@ class ControlBase:
|
||||
return out
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
||||
super().__init__(device)
|
||||
self.control_model = control_model
|
||||
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
|
||||
self.load_device = load_device
|
||||
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.model_sampling_current = None
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
control_prev = None
|
||||
@ -146,28 +152,31 @@ class ControlNet(ControlBase):
|
||||
else:
|
||||
return None
|
||||
|
||||
dtype = self.control_model.dtype
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
output_dtype = x_noisy.dtype
|
||||
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
||||
if self.cond_hint is not None:
|
||||
del self.cond_hint
|
||||
self.cond_hint = None
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
|
||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
|
||||
context = cond['c_crossattn']
|
||||
y = cond.get('y', None)
|
||||
if y is not None:
|
||||
y = y.to(self.control_model.dtype)
|
||||
y = y.to(dtype)
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
|
||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
@ -198,10 +207,11 @@ class ControlLoraOps:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = ops.cast_bias_weight(self, input)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
|
||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
class Conv2d(torch.nn.Module):
|
||||
def __init__(
|
||||
@ -237,16 +247,11 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = ops.cast_bias_weight(self, input)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
def conv_nd(self, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
return self.Conv2d(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
@ -260,17 +265,26 @@ class ControlLora(ControlNet):
|
||||
controlnet_config = model.model_config.unet_config.copy()
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
||||
controlnet_config["operations"] = ControlLoraOps()
|
||||
self.control_model = cldm.ControlNet(**controlnet_config)
|
||||
self.manual_cast_dtype = model.manual_cast_dtype
|
||||
dtype = model.get_dtype()
|
||||
self.control_model.to(dtype)
|
||||
if self.manual_cast_dtype is None:
|
||||
class control_lora_ops(ControlLoraOps, ops.disable_weight_init):
|
||||
pass
|
||||
else:
|
||||
class control_lora_ops(ControlLoraOps, ops.manual_cast):
|
||||
pass
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
controlnet_config["operations"] = control_lora_ops
|
||||
controlnet_config["dtype"] = dtype
|
||||
self.control_model = cldm.ControlNet(**controlnet_config)
|
||||
self.control_model.to(model_management.get_torch_device())
|
||||
diffusion_model = model.diffusion_model
|
||||
sd = diffusion_model.state_dict()
|
||||
cm = self.control_model.state_dict()
|
||||
|
||||
for k in sd:
|
||||
weight = model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
|
||||
weight = sd[k]
|
||||
try:
|
||||
utils.set_attr(self.control_model, k, weight)
|
||||
except:
|
||||
@ -367,6 +381,10 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if controlnet_config is None:
|
||||
unet_dtype = model_management.unet_dtype()
|
||||
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||
load_device = model_management.get_torch_device()
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = ops.manual_cast
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = cldm.ControlNet(**controlnet_config)
|
||||
@ -395,14 +413,12 @@ def load_controlnet(ckpt_path, model=None):
|
||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||
print(missing, unexpected)
|
||||
|
||||
control_model = control_model.to(unet_dtype)
|
||||
|
||||
global_average_pooling = False
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
||||
global_average_pooling = True
|
||||
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
|
||||
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
||||
return control
|
||||
|
||||
class T2IAdapter(ControlBase):
|
||||
|
||||
@ -33,3 +33,7 @@ class SDXL(LatentFormat):
|
||||
[-0.3112, -0.2359, -0.2076]
|
||||
]
|
||||
self.taesd_decoder_name = "taesdxl_decoder"
|
||||
|
||||
class SD_X4(LatentFormat):
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.08333
|
||||
|
||||
@ -9,6 +9,7 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from ..util import instantiate_from_config, get_obj_from_str
|
||||
from ..modules.ema import LitEma
|
||||
from ... import ops
|
||||
|
||||
class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
def __init__(self, sample: bool = True):
|
||||
@ -162,12 +163,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
self.quant_conv = torch.nn.Conv2d(
|
||||
self.quant_conv = ops.disable_weight_init.Conv2d(
|
||||
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
||||
(1 + ddconfig["double_z"]) * embed_dim,
|
||||
1,
|
||||
)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
|
||||
@ -18,6 +18,7 @@ if model_management.xformers_enabled():
|
||||
|
||||
from ...cli_args import args
|
||||
from ... import ops
|
||||
ops = ops.disable_weight_init
|
||||
|
||||
# CrossAttn precision handling
|
||||
if args.dont_upcast_attention:
|
||||
@ -82,16 +83,6 @@ class FeedForward(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
@ -112,19 +103,20 @@ def attention_basic(q, k, v, heads, mask=None):
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
q, k = q.float(), k.float()
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||
|
||||
del q, k
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
if mask.dtype == torch.bool:
|
||||
mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
else:
|
||||
sim += mask
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
@ -349,6 +341,18 @@ else:
|
||||
if model_management.pytorch_attention_enabled():
|
||||
optimized_attention_masked = attention_pytorch
|
||||
|
||||
def optimized_attention_for_device(device, mask=False):
|
||||
if device == torch.device("cpu"): #TODO
|
||||
if model_management.pytorch_attention_enabled():
|
||||
return attention_pytorch
|
||||
else:
|
||||
return attention_basic
|
||||
if mask:
|
||||
return optimized_attention_masked
|
||||
|
||||
return optimized_attention
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
@ -393,7 +397,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.is_res = inner_dim == dim
|
||||
|
||||
if self.ff_in:
|
||||
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.disable_self_attn = disable_self_attn
|
||||
@ -413,10 +417,10 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.checkpoint = checkpoint
|
||||
self.n_heads = n_heads
|
||||
self.d_head = d_head
|
||||
@ -558,7 +562,7 @@ class SpatialTransformer(nn.Module):
|
||||
context_dim = [context_dim] * depth
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels, dtype=dtype, device=device)
|
||||
self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
if not use_linear:
|
||||
self.proj_in = operations.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Optional, Any
|
||||
|
||||
from .... import model_management
|
||||
from .... import ops
|
||||
ops = ops.disable_weight_init
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
@ -40,7 +41,7 @@ def nonlinearity(x):
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
@ -12,13 +12,13 @@ from .util import (
|
||||
checkpoint,
|
||||
avg_pool_nd,
|
||||
zero_module,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
AlphaBlender,
|
||||
)
|
||||
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
|
||||
from ...util import exists
|
||||
from .... import ops
|
||||
ops = ops.disable_weight_init
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
|
||||
padding = kernel_size // 2
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, channels, dtype=dtype, device=device),
|
||||
operations.GroupNorm(32, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
|
||||
)
|
||||
@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
||||
operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
|
||||
),
|
||||
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
|
||||
,
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
@ -438,9 +437,6 @@ class UNetModel(nn.Module):
|
||||
operations=ops,
|
||||
):
|
||||
super().__init__()
|
||||
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
@ -457,7 +453,6 @@ class UNetModel(nn.Module):
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
@ -503,7 +498,7 @@ class UNetModel(nn.Module):
|
||||
|
||||
if self.num_classes is not None:
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
|
||||
elif self.num_classes == "continuous":
|
||||
print("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
@ -810,13 +805,13 @@ class UNetModel(nn.Module):
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
nn.SiLU(),
|
||||
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
||||
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
@ -842,14 +837,14 @@ class UNetModel(nn.Module):
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h = x
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
transformer_options["block"] = ("input", id)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
||||
|
||||
@ -41,10 +41,14 @@ class AbstractLowScaleModel(nn.Module):
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
||||
def q_sample(self, x_start, t, noise=None, seed=None):
|
||||
if noise is None:
|
||||
if seed is None:
|
||||
noise = torch.randn_like(x_start)
|
||||
else:
|
||||
noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
|
||||
return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
|
||||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
|
||||
|
||||
def forward(self, x):
|
||||
return x, None
|
||||
@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
|
||||
super().__init__(noise_schedule_config=noise_schedule_config)
|
||||
self.max_noise_level = max_noise_level
|
||||
|
||||
def forward(self, x, noise_level=None):
|
||||
def forward(self, x, noise_level=None, seed=None):
|
||||
if noise_level is None:
|
||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||
else:
|
||||
assert isinstance(noise_level, torch.Tensor)
|
||||
z = self.q_sample(x, noise_level)
|
||||
z = self.q_sample(x, noise_level, seed=seed)
|
||||
return z, noise_level
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,6 @@ import numpy as np
|
||||
from einops import repeat, rearrange
|
||||
|
||||
from ...util import instantiate_from_config
|
||||
from .... import ops
|
||||
|
||||
class AlphaBlender(nn.Module):
|
||||
strategies = ["learned", "fixed", "learned_with_images"]
|
||||
@ -52,9 +51,9 @@ class AlphaBlender(nn.Module):
|
||||
if self.merge_strategy == "fixed":
|
||||
# make shape compatible
|
||||
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
|
||||
alpha = self.mix_factor
|
||||
alpha = self.mix_factor.to(image_only_indicator.device)
|
||||
elif self.merge_strategy == "learned":
|
||||
alpha = torch.sigmoid(self.mix_factor)
|
||||
alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
|
||||
# make shape compatible
|
||||
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
|
||||
elif self.merge_strategy == "learned_with_images":
|
||||
@ -62,7 +61,7 @@ class AlphaBlender(nn.Module):
|
||||
alpha = torch.where(
|
||||
image_only_indicator.bool(),
|
||||
torch.ones(1, 1, device=image_only_indicator.device),
|
||||
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
|
||||
rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
|
||||
)
|
||||
alpha = rearrange(alpha, self.rearrange_pattern)
|
||||
# make shape compatible
|
||||
@ -273,46 +272,6 @@ def mean_flat(tensor):
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels, dtype=None):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(32, channels, dtype=dtype)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return ops.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return ops.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
|
||||
@ -15,12 +15,12 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
|
||||
|
||||
def scale(self, x):
|
||||
# re-normalize to centered mean and unit variance
|
||||
x = (x - self.data_mean) * 1. / self.data_std
|
||||
x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
|
||||
return x
|
||||
|
||||
def unscale(self, x):
|
||||
# back to original data stats
|
||||
x = (x * self.data_std) + self.data_mean
|
||||
x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
|
||||
return x
|
||||
|
||||
def forward(self, x, noise_level=None):
|
||||
|
||||
@ -5,6 +5,7 @@ import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ... import ops
|
||||
ops = ops.disable_weight_init
|
||||
|
||||
from .diffusionmodules.model import (
|
||||
AttnBlock,
|
||||
@ -81,14 +82,14 @@ class VideoResBlock(ResnetBlock):
|
||||
|
||||
x = self.time_stack(x, temb)
|
||||
|
||||
alpha = self.get_alpha(bs=b // timesteps)
|
||||
alpha = self.get_alpha(bs=b // timesteps).to(x.device)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix
|
||||
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
|
||||
|
||||
class AE3DConv(torch.nn.Conv2d):
|
||||
class AE3DConv(ops.Conv2d):
|
||||
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
|
||||
super().__init__(in_channels, out_channels, *args, **kwargs)
|
||||
if isinstance(video_kernel_size, Iterable):
|
||||
@ -96,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
|
||||
else:
|
||||
padding = int(video_kernel_size // 2)
|
||||
|
||||
self.time_mix_conv = torch.nn.Conv3d(
|
||||
self.time_mix_conv = ops.Conv3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=video_kernel_size,
|
||||
@ -166,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
|
||||
emb = emb[:, None, :]
|
||||
x_mix = x_mix + emb
|
||||
|
||||
alpha = self.get_alpha()
|
||||
alpha = self.get_alpha().to(x.device)
|
||||
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
||||
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ def load_lora(lora, to_load):
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
||||
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
@ -64,7 +64,7 @@ def load_lora(lora, to_load):
|
||||
loaded_keys.add(hada_t1_name)
|
||||
loaded_keys.add(hada_t2_name)
|
||||
|
||||
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
|
||||
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
|
||||
loaded_keys.add(hada_w1_a_name)
|
||||
loaded_keys.add(hada_w1_b_name)
|
||||
loaded_keys.add(hada_w2_a_name)
|
||||
@ -116,8 +116,19 @@ def load_lora(lora, to_load):
|
||||
loaded_keys.add(lokr_t2_name)
|
||||
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
|
||||
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
|
||||
|
||||
#glora
|
||||
a1_name = "{}.a1.weight".format(x)
|
||||
a2_name = "{}.a2.weight".format(x)
|
||||
b1_name = "{}.b1.weight".format(x)
|
||||
b2_name = "{}.b2.weight".format(x)
|
||||
if a1_name in lora:
|
||||
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
|
||||
loaded_keys.add(a1_name)
|
||||
loaded_keys.add(a2_name)
|
||||
loaded_keys.add(b1_name)
|
||||
loaded_keys.add(b2_name)
|
||||
|
||||
w_norm_name = "{}.w_norm".format(x)
|
||||
b_norm_name = "{}.b_norm".format(x)
|
||||
@ -126,21 +137,21 @@ def load_lora(lora, to_load):
|
||||
|
||||
if w_norm is not None:
|
||||
loaded_keys.add(w_norm_name)
|
||||
patch_dict[to_load[x]] = (w_norm,)
|
||||
patch_dict[to_load[x]] = ("diff", (w_norm,))
|
||||
if b_norm is not None:
|
||||
loaded_keys.add(b_norm_name)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
|
||||
|
||||
diff_name = "{}.diff".format(x)
|
||||
diff_weight = lora.get(diff_name, None)
|
||||
if diff_weight is not None:
|
||||
patch_dict[to_load[x]] = (diff_weight,)
|
||||
patch_dict[to_load[x]] = ("diff", (diff_weight,))
|
||||
loaded_keys.add(diff_name)
|
||||
|
||||
diff_bias_name = "{}.diff_b".format(x)
|
||||
diff_bias = lora.get(diff_bias_name, None)
|
||||
if diff_bias is not None:
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
for x in lora.keys():
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import torch
|
||||
from .ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from .ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from . import model_management
|
||||
from . import conds
|
||||
from . import ops
|
||||
from enum import Enum
|
||||
import contextlib
|
||||
from . import utils
|
||||
|
||||
class ModelType(Enum):
|
||||
@ -13,7 +15,7 @@ class ModelType(Enum):
|
||||
V_PREDICTION_EDM = 3
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
|
||||
from .model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@ -40,9 +42,14 @@ class BaseModel(torch.nn.Module):
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||
if self.manual_cast_dtype is not None:
|
||||
operations = ops.manual_cast
|
||||
else:
|
||||
operations = ops.disable_weight_init
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
@ -61,15 +68,21 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
|
||||
xc = xc.to(dtype)
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
context = context.to(dtype)
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
if hasattr(extra, "to"):
|
||||
extra = extra.to(dtype)
|
||||
if hasattr(extra, "dtype"):
|
||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||
extra = extra.to(dtype)
|
||||
extra_conds[o] = extra
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
@ -117,6 +130,10 @@ class BaseModel(torch.nn.Module):
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = conds.CONDRegular(adm)
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
|
||||
return out
|
||||
|
||||
def load_model_weights(self, sd, unet_prefix=""):
|
||||
@ -144,11 +161,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
|
||||
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
|
||||
unet_sd = self.diffusion_model.state_dict()
|
||||
unet_state_dict = {}
|
||||
for k in unet_sd:
|
||||
unet_state_dict[k] = model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
|
||||
if self.get_dtype() == torch.float16:
|
||||
@ -165,9 +178,12 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
def memory_required(self, input_shape):
|
||||
if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
|
||||
dtype = self.get_dtype()
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
#TODO: this needs to be tweaked
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
return (area * model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
|
||||
return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
|
||||
else:
|
||||
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
|
||||
area = input_shape[0] * input_shape[2] * input_shape[3]
|
||||
@ -307,9 +323,75 @@ class SVD_img2vid(BaseModel):
|
||||
|
||||
out['c_concat'] = conds.CONDNoiseShape(latent_image)
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
|
||||
|
||||
if "time_conditioning" in kwargs:
|
||||
out["time_context"] = conds.CONDCrossAttn(kwargs["time_conditioning"])
|
||||
|
||||
out['image_only_indicator'] = conds.CONDConstant(torch.zeros((1,), device=device))
|
||||
out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
|
||||
return out
|
||||
|
||||
class Stable_Zero123(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.cc_projection = ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
|
||||
self.cc_projection.weight.copy_(cc_projection_weight)
|
||||
self.cc_projection.bias.copy_(cc_projection_bias)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
latent_image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
|
||||
if latent_image is None:
|
||||
latent_image = torch.zeros_like(noise)
|
||||
|
||||
if latent_image.shape[1:] != noise.shape[1:]:
|
||||
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
|
||||
|
||||
out['c_concat'] = conds.CONDNoiseShape(latent_image)
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if cross_attn.shape[-1] != 768:
|
||||
cross_attn = self.cc_projection(cross_attn)
|
||||
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
|
||||
return out
|
||||
|
||||
class SD_X4Upscaler(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
image = kwargs.get("concat_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
noise_augment = kwargs.get("noise_augmentation", 0.0)
|
||||
device = kwargs["device"]
|
||||
seed = kwargs["seed"] - 10
|
||||
|
||||
noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)[:,:3]
|
||||
|
||||
if image.shape[1:] != noise.shape[1:]:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
noise_level = torch.tensor([noise_level], device=device)
|
||||
if noise_augment > 0:
|
||||
image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
out['c_concat'] = conds.CONDNoiseShape(image)
|
||||
out['y'] = conds.CONDRegular(noise_level)
|
||||
return out
|
||||
|
||||
@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
unet_config = {
|
||||
"use_checkpoint": False,
|
||||
"image_size": 32,
|
||||
"out_channels": 4,
|
||||
"use_spatial_transformer": True,
|
||||
"legacy": False
|
||||
}
|
||||
@ -50,6 +49,12 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
|
||||
out_key = '{}out.2.weight'.format(key_prefix)
|
||||
if out_key in state_dict:
|
||||
out_channels = state_dict[out_key].shape[0]
|
||||
else:
|
||||
out_channels = 4
|
||||
|
||||
num_res_blocks = []
|
||||
channel_mult = []
|
||||
attention_resolutions = []
|
||||
@ -122,6 +127,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
transformer_depth_middle = -1
|
||||
|
||||
unet_config["in_channels"] = in_channels
|
||||
unet_config["out_channels"] = out_channels
|
||||
unet_config["model_channels"] = model_channels
|
||||
unet_config["num_res_blocks"] = num_res_blocks
|
||||
unet_config["transformer_depth"] = transformer_depth
|
||||
@ -289,7 +295,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
|
||||
Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
|
||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
|
||||
@ -2,6 +2,7 @@ import psutil
|
||||
from enum import Enum
|
||||
from .cli_args import args
|
||||
from . import utils
|
||||
|
||||
import torch
|
||||
import sys
|
||||
|
||||
@ -28,6 +29,10 @@ total_vram = 0
|
||||
lowvram_available = True
|
||||
xpu_available = False
|
||||
|
||||
if args.deterministic:
|
||||
print("Using deterministic algorithms for pytorch")
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml is not None:
|
||||
import torch_directml
|
||||
@ -182,6 +187,9 @@ except:
|
||||
if is_intel_xpu():
|
||||
VAE_DTYPE = torch.bfloat16
|
||||
|
||||
if args.cpu_vae:
|
||||
VAE_DTYPE = torch.float32
|
||||
|
||||
if args.fp16_vae:
|
||||
VAE_DTYPE = torch.float16
|
||||
elif args.bf16_vae:
|
||||
@ -214,15 +222,8 @@ if args.force_fp16 or cpu_state == CPUState.MPS:
|
||||
FORCE_FP16 = True
|
||||
|
||||
if lowvram_available:
|
||||
try:
|
||||
import accelerate
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
vram_state = set_vram_to
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
|
||||
lowvram_available = False
|
||||
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
|
||||
vram_state = set_vram_to
|
||||
|
||||
|
||||
if cpu_state != CPUState.GPU:
|
||||
@ -262,6 +263,14 @@ print("VAE dtype:", VAE_DTYPE)
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nelement() * t.element_size()
|
||||
return module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
@ -294,8 +303,20 @@ class LoadedModel:
|
||||
|
||||
if lowvram_model_memory > 0:
|
||||
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
|
||||
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
|
||||
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
|
||||
mem_counter = 0
|
||||
for m in self.real_model.modules():
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
module_mem = module_size(m)
|
||||
if mem_counter + module_mem < lowvram_model_memory:
|
||||
m.to(self.device)
|
||||
mem_counter += module_mem
|
||||
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
|
||||
m.to(self.device)
|
||||
mem_counter += module_size(m)
|
||||
print("lowvram: loaded module regularly", m)
|
||||
|
||||
self.model_accelerated = True
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||
@ -305,7 +326,11 @@ class LoadedModel:
|
||||
|
||||
def model_unload(self):
|
||||
if self.model_accelerated:
|
||||
accelerate.hooks.remove_hook_from_submodules(self.real_model)
|
||||
for m in self.real_model.modules():
|
||||
if hasattr(m, "prev_comfy_cast_weights"):
|
||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||
del m.prev_comfy_cast_weights
|
||||
|
||||
self.model_accelerated = False
|
||||
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
@ -398,14 +423,14 @@ def load_models_gpu(models, memory_required=0):
|
||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||
model_size = loaded_model.model_memory_required(torch_dev)
|
||||
current_free_mem = get_free_memory(torch_dev)
|
||||
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
|
||||
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
|
||||
vram_set_state = VRAMState.LOW_VRAM
|
||||
else:
|
||||
lowvram_model_memory = 0
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 256 * 1024 * 1024
|
||||
lowvram_model_memory = 64 * 1024 * 1024
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
@ -430,6 +455,13 @@ def dtype_size(dtype):
|
||||
dtype_size = 4
|
||||
if dtype == torch.float16 or dtype == torch.bfloat16:
|
||||
dtype_size = 2
|
||||
elif dtype == torch.float32:
|
||||
dtype_size = 4
|
||||
else:
|
||||
try:
|
||||
dtype_size = dtype.itemsize
|
||||
except: #Old pytorch doesn't have .itemsize
|
||||
pass
|
||||
return dtype_size
|
||||
|
||||
def unet_offload_device():
|
||||
@ -459,10 +491,30 @@ def unet_inital_load_device(parameters, dtype):
|
||||
def unet_dtype(device=None, model_params=0):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if args.fp16_unet:
|
||||
return torch.float16
|
||||
if args.fp8_e4m3fn_unet:
|
||||
return torch.float8_e4m3fn
|
||||
if args.fp8_e5m2_unet:
|
||||
return torch.float8_e5m2
|
||||
if should_use_fp16(device=device, model_params=model_params):
|
||||
return torch.float16
|
||||
return torch.float32
|
||||
|
||||
# None means no manual cast
|
||||
def unet_manual_cast(weight_dtype, inference_device):
|
||||
if weight_dtype == torch.float32:
|
||||
return None
|
||||
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
|
||||
if fp16_supported and weight_dtype == torch.float16:
|
||||
return None
|
||||
|
||||
if fp16_supported:
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
@ -492,12 +544,23 @@ def text_encoder_dtype(device=None):
|
||||
elif args.fp32_text_enc:
|
||||
return torch.float32
|
||||
|
||||
if is_device_cpu(device):
|
||||
return torch.float16
|
||||
|
||||
if should_use_fp16(device, prioritize_performance=False):
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
def intermediate_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def vae_device():
|
||||
if args.cpu_vae:
|
||||
return torch.device("cpu")
|
||||
return get_torch_device()
|
||||
|
||||
def vae_offload_device():
|
||||
@ -515,6 +578,22 @@ def get_autocast_device(dev):
|
||||
return dev.type
|
||||
return "cuda"
|
||||
|
||||
def supports_dtype(device, dtype): #TODO
|
||||
if dtype == torch.float32:
|
||||
return True
|
||||
if is_device_cpu(device):
|
||||
return False
|
||||
if dtype == torch.float16:
|
||||
return True
|
||||
if dtype == torch.bfloat16:
|
||||
return True
|
||||
return False
|
||||
|
||||
def device_supports_non_blocking(device):
|
||||
if is_device_mps(device):
|
||||
return False #pytorch bug? mps doesn't support non blocking
|
||||
return True
|
||||
|
||||
def cast_to_device(tensor, device, dtype, copy=False):
|
||||
device_supports_cast = False
|
||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
||||
@ -525,15 +604,17 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
elif is_intel_xpu():
|
||||
device_supports_cast = True
|
||||
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
|
||||
if device_supports_cast:
|
||||
if copy:
|
||||
if tensor.device == device:
|
||||
return tensor.to(dtype, copy=copy)
|
||||
return tensor.to(device, copy=copy).to(dtype)
|
||||
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
|
||||
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
return tensor.to(device).to(dtype)
|
||||
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
return tensor.to(dtype).to(device, copy=copy)
|
||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
||||
|
||||
def xformers_enabled():
|
||||
global directml_enabled
|
||||
@ -687,11 +768,11 @@ def soft_empty_cache(force=False):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def resolve_lowvram_weight(weight, model, key):
|
||||
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
|
||||
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
|
||||
op = utils.get_attr(model, '.'.join(key_split[:-1]))
|
||||
weight = op._hf_hook.weights_map[key_split[-1]]
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
|
||||
def resolve_lowvram_weight(weight, model, key): #TODO: remove
|
||||
return weight
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
|
||||
@ -28,13 +28,9 @@ class ModelPatcher:
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
model_sd = self.model.state_dict()
|
||||
size = 0
|
||||
for k in model_sd:
|
||||
t = model_sd[k]
|
||||
size += t.nelement() * t.element_size()
|
||||
self.size = size
|
||||
self.size = model_management.module_size(self.model)
|
||||
self.model_keys = set(model_sd.keys())
|
||||
return size
|
||||
return self.size
|
||||
|
||||
def clone(self):
|
||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
||||
@ -55,11 +51,18 @@ class ModelPatcher:
|
||||
def memory_required(self, input_shape):
|
||||
return self.model.memory_required(input_shape=input_shape)
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function):
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
else:
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
if disable_cfg1_optimization:
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
|
||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
|
||||
if disable_cfg1_optimization:
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
|
||||
def set_model_unet_function_wrapper(self, unet_wrapper_function):
|
||||
self.model_options["model_function_wrapper"] = unet_wrapper_function
|
||||
@ -70,13 +73,17 @@ class ModelPatcher:
|
||||
to["patches"] = {}
|
||||
to["patches"][name] = to["patches"].get(name, []) + [patch]
|
||||
|
||||
def set_model_patch_replace(self, patch, name, block_name, number):
|
||||
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches_replace" not in to:
|
||||
to["patches_replace"] = {}
|
||||
if name not in to["patches_replace"]:
|
||||
to["patches_replace"][name] = {}
|
||||
to["patches_replace"][name][(block_name, number)] = patch
|
||||
if transformer_index is not None:
|
||||
block = (block_name, number, transformer_index)
|
||||
else:
|
||||
block = (block_name, number)
|
||||
to["patches_replace"][name][block] = patch
|
||||
|
||||
def set_model_attn1_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_patch")
|
||||
@ -84,11 +91,11 @@ class ModelPatcher:
|
||||
def set_model_attn2_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn2_patch")
|
||||
|
||||
def set_model_attn1_replace(self, patch, block_name, number):
|
||||
self.set_model_patch_replace(patch, "attn1", block_name, number)
|
||||
def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
|
||||
self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
|
||||
|
||||
def set_model_attn2_replace(self, patch, block_name, number):
|
||||
self.set_model_patch_replace(patch, "attn2", block_name, number)
|
||||
def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
|
||||
self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
|
||||
|
||||
def set_model_attn1_output_patch(self, patch):
|
||||
self.set_model_patch(patch, "attn1_output_patch")
|
||||
@ -167,40 +174,41 @@ class ModelPatcher:
|
||||
sd.pop(k)
|
||||
return sd
|
||||
|
||||
def patch_model(self, device_to=None):
|
||||
def patch_model(self, device_to=None, patch_weights=True):
|
||||
for k in self.object_patches:
|
||||
old = getattr(self.model, k)
|
||||
if k not in self.object_patches_backup:
|
||||
self.object_patches_backup[k] = old
|
||||
setattr(self.model, k, self.object_patches[k])
|
||||
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
print("could not patch. key doesn't exist in model:", key)
|
||||
continue
|
||||
if patch_weights:
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
print("could not patch. key doesn't exist in model:", key)
|
||||
continue
|
||||
|
||||
weight = model_sd[key]
|
||||
weight = model_sd[key]
|
||||
|
||||
inplace_update = self.weight_inplace_update
|
||||
inplace_update = self.weight_inplace_update
|
||||
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||
if key not in self.backup:
|
||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
if inplace_update:
|
||||
utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
utils.set_attr(self.model, key, out_weight)
|
||||
del temp_weight
|
||||
|
||||
if device_to is not None:
|
||||
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
if inplace_update:
|
||||
utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
utils.set_attr(self.model, key, out_weight)
|
||||
del temp_weight
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
return self.model
|
||||
|
||||
@ -217,13 +225,19 @@ class ModelPatcher:
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
elif len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
|
||||
if patch_type == "diff":
|
||||
w1 = v[0]
|
||||
if alpha != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
|
||||
elif len(v) == 4: #lora/locon
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||
mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||
if v[2] is not None:
|
||||
@ -237,7 +251,7 @@ class ModelPatcher:
|
||||
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
elif len(v) == 8: #lokr
|
||||
elif patch_type == "lokr":
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
@ -276,7 +290,7 @@ class ModelPatcher:
|
||||
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
else: #loha
|
||||
elif patch_type == "loha":
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
@ -305,6 +319,18 @@ class ModelPatcher:
|
||||
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
|
||||
except Exception as e:
|
||||
print("ERROR", key, e)
|
||||
elif patch_type == "glora":
|
||||
if v[4] is not None:
|
||||
alpha *= v[4] / v[0].shape[0]
|
||||
|
||||
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
|
||||
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
|
||||
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
||||
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
||||
|
||||
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
|
||||
else:
|
||||
print("patch type not recognized", patch_type, key)
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
@ -22,10 +22,17 @@ class V_PREDICTION(EPS):
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
beta_schedule = "linear"
|
||||
|
||||
if model_config is not None:
|
||||
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
sampling_settings = model_config.sampling_settings
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
beta_schedule = sampling_settings.get("beta_schedule", "linear")
|
||||
linear_start = sampling_settings.get("linear_start", 0.00085)
|
||||
linear_end = sampling_settings.get("linear_end", 0.012)
|
||||
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
|
||||
self.sigma_data = 1.0
|
||||
|
||||
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
|
||||
@ -6,7 +6,7 @@ import hashlib
|
||||
import math
|
||||
import random
|
||||
|
||||
from PIL import Image, ImageOps
|
||||
from PIL import Image, ImageOps, ImageSequence
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
@ -930,8 +930,8 @@ class GLIGENTextBoxApply:
|
||||
return (c, )
|
||||
|
||||
class EmptyLatentImage:
|
||||
def __init__(self, device="cpu"):
|
||||
self.device = device
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -944,7 +944,7 @@ class EmptyLatentImage:
|
||||
CATEGORY = "latent"
|
||||
|
||||
def generate(self, width, height, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
|
||||
return ({"samples":latent}, )
|
||||
|
||||
|
||||
@ -1395,17 +1395,30 @@ class LoadImage:
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = Image.open(image_path)
|
||||
i = ImageOps.exif_transpose(i)
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image)[None,]
|
||||
if 'A' in i.getbands():
|
||||
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
img = Image.open(image_path)
|
||||
output_images = []
|
||||
output_masks = []
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = ImageOps.exif_transpose(i)
|
||||
image = i.convert("RGB")
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image)[None,]
|
||||
if 'A' in i.getbands():
|
||||
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
output_images.append(image)
|
||||
output_masks.append(mask.unsqueeze(0))
|
||||
|
||||
if len(output_images) > 1:
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
return (image, mask.unsqueeze(0))
|
||||
output_image = output_images[0]
|
||||
output_mask = output_masks[0]
|
||||
|
||||
return (output_image, output_mask)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
@ -1463,13 +1476,10 @@ class LoadImageMask:
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image, channel):
|
||||
def VALIDATE_INPUTS(s, image):
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
|
||||
if channel not in s._color_channels:
|
||||
return "Invalid color channel: {}".format(channel)
|
||||
|
||||
return True
|
||||
|
||||
class ImageScale:
|
||||
|
||||
139
comfy/ops.py
139
comfy/ops.py
@ -1,40 +1,115 @@
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
import comfy.model_management
|
||||
|
||||
class Linear(torch.nn.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
def cast_bias_weight(s, input):
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
return weight, bias
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
class disable_weight_init:
|
||||
class Linear(torch.nn.Linear):
|
||||
comfy_cast_weights = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
return Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return Conv3d(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
@contextmanager
|
||||
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
|
||||
old_torch_nn_linear = torch.nn.Linear
|
||||
force_device = device
|
||||
force_dtype = dtype
|
||||
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||
if force_device is not None:
|
||||
device = force_device
|
||||
if force_dtype is not None:
|
||||
dtype = force_dtype
|
||||
return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
torch.nn.Linear = linear_with_dtype
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch.nn.Linear = old_torch_nn_linear
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
comfy_cast_weights = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
comfy_cast_weights = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class GroupNorm(torch.nn.GroupNorm):
|
||||
comfy_cast_weights = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
comfy_cast_weights = False
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def conv_nd(s, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
return s.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return s.Conv3d(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
class manual_cast(disable_weight_init):
|
||||
class Linear(disable_weight_init.Linear):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class Conv2d(disable_weight_init.Conv2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class Conv3d(disable_weight_init.Conv3d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class GroupNorm(disable_weight_init.GroupNorm):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class LayerNorm(disable_weight_init.LayerNorm):
|
||||
comfy_cast_weights = True
|
||||
|
||||
9
comfy/package_data_path_helper.py
Normal file
9
comfy/package_data_path_helper.py
Normal file
@ -0,0 +1,9 @@
|
||||
from importlib.resources import path
|
||||
import os
|
||||
|
||||
|
||||
def get_editable_resource_path(caller_file, *package_path):
|
||||
filename = os.path.join(os.path.dirname(os.path.realpath(caller_file)), package_path[-1])
|
||||
if not os.path.exists(filename):
|
||||
filename = path(*package_path)
|
||||
return filename
|
||||
@ -47,7 +47,8 @@ def convert_cond(cond):
|
||||
temp = c[1].copy()
|
||||
model_conds = temp.get("model_conds", {})
|
||||
if c[0] is not None:
|
||||
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0])
|
||||
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
|
||||
temp["cross_attn"] = c[0]
|
||||
temp["model_conds"] = model_conds
|
||||
out.append(temp)
|
||||
return out
|
||||
@ -98,10 +99,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
||||
sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||
|
||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.cpu()
|
||||
samples = samples.to(model_management.intermediate_device())
|
||||
|
||||
cleanup_additional_models(models)
|
||||
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
|
||||
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
|
||||
return samples
|
||||
|
||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
@ -111,8 +112,8 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
|
||||
sigmas = sigmas.to(model.load_device)
|
||||
|
||||
samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||
samples = samples.cpu()
|
||||
samples = samples.to(model_management.intermediate_device())
|
||||
cleanup_additional_models(models)
|
||||
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
|
||||
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
|
||||
return samples
|
||||
|
||||
|
||||
@ -1,256 +1,264 @@
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import collections
|
||||
from . import model_management
|
||||
import math
|
||||
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
|
||||
if 'timestep_start' in conds:
|
||||
timestep_start = conds['timestep_start']
|
||||
if timestep_in[0] > timestep_start:
|
||||
return None
|
||||
if 'timestep_end' in conds:
|
||||
timestep_end = conds['timestep_end']
|
||||
if timestep_in[0] < timestep_end:
|
||||
return None
|
||||
if 'area' in conds:
|
||||
area = conds['area']
|
||||
if 'strength' in conds:
|
||||
strength = conds['strength']
|
||||
|
||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
if 'mask' in conds:
|
||||
# Scale the mask to the size of the input
|
||||
# The mask should have been resized as we began the sampling process
|
||||
mask_strength = 1.0
|
||||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert(mask.shape[1] == x_in.shape[2])
|
||||
assert(mask.shape[2] == x_in.shape[3])
|
||||
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||
else:
|
||||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
|
||||
if 'mask' not in conds:
|
||||
rr = 8
|
||||
if area[2] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
|
||||
if (area[0] + area[2]) < x_in.shape[2]:
|
||||
for t in range(rr):
|
||||
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
|
||||
if area[3] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
|
||||
if (area[1] + area[3]) < x_in.shape[3]:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||
|
||||
conditioning = {}
|
||||
model_conds = conds["model_conds"]
|
||||
for c in model_conds:
|
||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
|
||||
control = conds.get('control', None)
|
||||
|
||||
patches = None
|
||||
if 'gligen' in conds:
|
||||
gligen = conds['gligen']
|
||||
patches = {}
|
||||
gligen_type = gligen[0]
|
||||
gligen_model = gligen[1]
|
||||
if gligen_type == "position":
|
||||
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
|
||||
else:
|
||||
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
|
||||
|
||||
patches['middle_patch'] = [gligen_patch]
|
||||
|
||||
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
|
||||
return cond_obj(input_x, mult, conditioning, area, control, patches)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
return True
|
||||
if c1.keys() != c2.keys():
|
||||
return False
|
||||
for k in c1:
|
||||
if not c1[k].can_concat(c2[k]):
|
||||
return False
|
||||
return True
|
||||
|
||||
def can_concat_cond(c1, c2):
|
||||
if c1.input_x.shape != c2.input_x.shape:
|
||||
return False
|
||||
|
||||
def objects_concatable(obj1, obj2):
|
||||
if (obj1 is None) != (obj2 is None):
|
||||
return False
|
||||
if obj1 is not None:
|
||||
if obj1 is not obj2:
|
||||
return False
|
||||
return True
|
||||
|
||||
if not objects_concatable(c1.control, c2.control):
|
||||
return False
|
||||
|
||||
if not objects_concatable(c1.patches, c2.patches):
|
||||
return False
|
||||
|
||||
return cond_equal_size(c1.conditioning, c2.conditioning)
|
||||
|
||||
def cond_cat(c_list):
|
||||
c_crossattn = []
|
||||
c_concat = []
|
||||
c_adm = []
|
||||
crossattn_max_len = 0
|
||||
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
cur = temp.get(k, [])
|
||||
cur.append(x[k])
|
||||
temp[k] = cur
|
||||
|
||||
out = {}
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
out_uncond = torch.zeros_like(x_in)
|
||||
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
COND = 0
|
||||
UNCOND = 1
|
||||
|
||||
to_run = []
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, COND)]
|
||||
if uncond is not None:
|
||||
for x in uncond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, UNCOND)]
|
||||
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
||||
|
||||
transformer_options = {}
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = model_options['transformer_options'].copy()
|
||||
|
||||
if patches is not None:
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
for o in range(batch_chunks):
|
||||
if cond_or_uncond[o] == COND:
|
||||
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
else:
|
||||
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
del mult
|
||||
|
||||
out_cond /= out_count
|
||||
del out_count
|
||||
out_uncond /= out_uncond_count
|
||||
del out_uncond_count
|
||||
return out_cond, out_uncond
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns denoised
|
||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
strength = 1.0
|
||||
|
||||
if 'timestep_start' in conds:
|
||||
timestep_start = conds['timestep_start']
|
||||
if timestep_in[0] > timestep_start:
|
||||
return None
|
||||
if 'timestep_end' in conds:
|
||||
timestep_end = conds['timestep_end']
|
||||
if timestep_in[0] < timestep_end:
|
||||
return None
|
||||
if 'area' in conds:
|
||||
area = conds['area']
|
||||
if 'strength' in conds:
|
||||
strength = conds['strength']
|
||||
|
||||
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||
if 'mask' in conds:
|
||||
# Scale the mask to the size of the input
|
||||
# The mask should have been resized as we began the sampling process
|
||||
mask_strength = 1.0
|
||||
if "mask_strength" in conds:
|
||||
mask_strength = conds["mask_strength"]
|
||||
mask = conds['mask']
|
||||
assert(mask.shape[1] == x_in.shape[2])
|
||||
assert(mask.shape[2] == x_in.shape[3])
|
||||
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
|
||||
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
|
||||
else:
|
||||
mask = torch.ones_like(input_x)
|
||||
mult = mask * strength
|
||||
|
||||
if 'mask' not in conds:
|
||||
rr = 8
|
||||
if area[2] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
|
||||
if (area[0] + area[2]) < x_in.shape[2]:
|
||||
for t in range(rr):
|
||||
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
|
||||
if area[3] != 0:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
|
||||
if (area[1] + area[3]) < x_in.shape[3]:
|
||||
for t in range(rr):
|
||||
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
|
||||
|
||||
conditionning = {}
|
||||
model_conds = conds["model_conds"]
|
||||
for c in model_conds:
|
||||
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
||||
|
||||
control = None
|
||||
if 'control' in conds:
|
||||
control = conds['control']
|
||||
|
||||
patches = None
|
||||
if 'gligen' in conds:
|
||||
gligen = conds['gligen']
|
||||
patches = {}
|
||||
gligen_type = gligen[0]
|
||||
gligen_model = gligen[1]
|
||||
if gligen_type == "position":
|
||||
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
|
||||
else:
|
||||
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
|
||||
|
||||
patches['middle_patch'] = [gligen_patch]
|
||||
|
||||
return (input_x, mult, conditionning, area, control, patches)
|
||||
|
||||
def cond_equal_size(c1, c2):
|
||||
if c1 is c2:
|
||||
return True
|
||||
if c1.keys() != c2.keys():
|
||||
return False
|
||||
for k in c1:
|
||||
if not c1[k].can_concat(c2[k]):
|
||||
return False
|
||||
return True
|
||||
|
||||
def can_concat_cond(c1, c2):
|
||||
if c1[0].shape != c2[0].shape:
|
||||
return False
|
||||
|
||||
#control
|
||||
if (c1[4] is None) != (c2[4] is None):
|
||||
return False
|
||||
if c1[4] is not None:
|
||||
if c1[4] is not c2[4]:
|
||||
return False
|
||||
|
||||
#patches
|
||||
if (c1[5] is None) != (c2[5] is None):
|
||||
return False
|
||||
if (c1[5] is not None):
|
||||
if c1[5] is not c2[5]:
|
||||
return False
|
||||
|
||||
return cond_equal_size(c1[2], c2[2])
|
||||
|
||||
def cond_cat(c_list):
|
||||
c_crossattn = []
|
||||
c_concat = []
|
||||
c_adm = []
|
||||
crossattn_max_len = 0
|
||||
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
cur = temp.get(k, [])
|
||||
cur.append(x[k])
|
||||
temp[k] = cur
|
||||
|
||||
out = {}
|
||||
for k in temp:
|
||||
conds = temp[k]
|
||||
out[k] = conds[0].concat(conds[1:])
|
||||
|
||||
return out
|
||||
|
||||
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
out_uncond = torch.zeros_like(x_in)
|
||||
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
COND = 0
|
||||
UNCOND = 1
|
||||
|
||||
to_run = []
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, COND)]
|
||||
if uncond is not None:
|
||||
for x in uncond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, UNCOND)]
|
||||
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
input_x += [p[0]]
|
||||
mult += [p[1]]
|
||||
c += [p[2]]
|
||||
area += [p[3]]
|
||||
cond_or_uncond += [o[1]]
|
||||
control = p[4]
|
||||
patches = p[5]
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
if control is not None:
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
||||
|
||||
transformer_options = {}
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = model_options['transformer_options'].copy()
|
||||
|
||||
if patches is not None:
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
for o in range(batch_chunks):
|
||||
if cond_or_uncond[o] == COND:
|
||||
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
else:
|
||||
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
|
||||
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
del mult
|
||||
|
||||
out_cond /= out_count
|
||||
del out_count
|
||||
out_uncond /= out_uncond_count
|
||||
del out_uncond_count
|
||||
return out_cond, out_uncond
|
||||
|
||||
|
||||
if math.isclose(cond_scale, 1.0):
|
||||
uncond = None
|
||||
|
||||
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||
return x - model_options["sampler_cfg_function"](args)
|
||||
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
||||
uncond_ = None
|
||||
else:
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
uncond_ = uncond
|
||||
|
||||
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
|
||||
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
|
||||
cfg_result = x - model_options["sampler_cfg_function"](args)
|
||||
else:
|
||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||
|
||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||
"sigma": timestep, "model_options": model_options, "input": x}
|
||||
cfg_result = fn(args)
|
||||
|
||||
return cfg_result
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
@ -272,10 +280,7 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
|
||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
|
||||
if denoise_mask is not None:
|
||||
out *= denoise_mask
|
||||
|
||||
if denoise_mask is not None:
|
||||
out += self.latent_image * latent_mask
|
||||
out = out * denoise_mask + self.latent_image * latent_mask
|
||||
return out
|
||||
|
||||
def simple_scheduler(model, steps):
|
||||
@ -590,6 +595,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
calculate_start_end_timesteps(model, negative)
|
||||
calculate_start_end_timesteps(model, positive)
|
||||
|
||||
if latent_image is not None:
|
||||
latent_image = model.process_latent_in(latent_image)
|
||||
|
||||
if hasattr(model, 'extra_conds'):
|
||||
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
create_cond_with_same_area_if_none(negative, c)
|
||||
@ -601,13 +613,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
||||
if latent_image is not None:
|
||||
latent_image = model.process_latent_in(latent_image)
|
||||
|
||||
if hasattr(model, 'extra_conds'):
|
||||
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
|
||||
|
||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||
|
||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
@ -630,7 +635,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = normal_scheduler(model, steps, sgm=True)
|
||||
else:
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
print("error invalid scheduler", scheduler_name)
|
||||
return sigmas
|
||||
|
||||
def sampler_object(name):
|
||||
|
||||
48
comfy/sd.py
48
comfy/sd.py
@ -148,12 +148,14 @@ class CLIP:
|
||||
return self.patcher.get_key_patches()
|
||||
|
||||
class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None):
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None):
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
||||
self.downscale_ratio = 8
|
||||
self.latent_channels = 4
|
||||
|
||||
if config is None:
|
||||
if "decoder.mid.block_1.mix_factor" in sd:
|
||||
@ -169,6 +171,11 @@ class VAE:
|
||||
else:
|
||||
#default SD1.x/SD2.x VAE parameters
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
|
||||
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
|
||||
ddconfig['ch_mult'] = [1, 2, 4]
|
||||
self.downscale_ratio = 4
|
||||
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
else:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
@ -185,8 +192,11 @@ class VAE:
|
||||
device = model_management.vae_device()
|
||||
self.device = device
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.vae_dtype = model_management.vae_dtype()
|
||||
if dtype is None:
|
||||
dtype = model_management.vae_dtype()
|
||||
self.vae_dtype = dtype
|
||||
self.first_stage_model.to(self.vae_dtype)
|
||||
self.output_device = model_management.intermediate_device()
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
|
||||
@ -198,9 +208,9 @@ class VAE:
|
||||
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
|
||||
output = torch.clamp((
|
||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
|
||||
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
|
||||
/ 3.0) / 2.0, min=0.0, max=1.0)
|
||||
return output
|
||||
|
||||
@ -211,9 +221,9 @@ class VAE:
|
||||
pbar = utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples /= 3.0
|
||||
return samples
|
||||
|
||||
@ -225,15 +235,15 @@ class VAE:
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
|
||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
|
||||
pixel_samples = pixel_samples.cpu().movedim(1,-1)
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
@ -249,10 +259,10 @@ class VAE:
|
||||
free_memory = model_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
|
||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
@ -429,11 +439,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
|
||||
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
||||
model_config.set_manual_cast(manual_cast_dtype)
|
||||
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
|
||||
@ -466,7 +480,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
print("left over keys:", left_over)
|
||||
|
||||
if output_model:
|
||||
_model_patcher = model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
print("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
@ -477,6 +491,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
parameters = utils.calculate_parameters(sd)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
load_device = model_management.get_torch_device()
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
if "input_blocks.0.0.weight" in sd: #ldm
|
||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||
if model_config is None:
|
||||
@ -497,13 +514,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
||||
else:
|
||||
print(diffusers_keys[k], k)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model_config.set_manual_cast(manual_cast_dtype)
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
print("left over keys in unet:", left_over)
|
||||
return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
def load_unet(unet_path):
|
||||
sd = utils.load_torch_file(unet_path)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import os
|
||||
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
|
||||
from transformers import CLIPTokenizer
|
||||
from . import ops
|
||||
import torch
|
||||
import traceback
|
||||
@ -8,6 +8,8 @@ import zipfile
|
||||
from . import model_management
|
||||
from pkg_resources import resource_filename
|
||||
import contextlib
|
||||
from . import clip_model
|
||||
import json
|
||||
|
||||
def gen_empty_tokens(special_tokens, length):
|
||||
start_token = special_tokens.get("start", None)
|
||||
@ -38,7 +40,7 @@ class ClipTokenWeightEncoder:
|
||||
|
||||
out, pooled = self.encode(to_encode)
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].cpu()
|
||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||
else:
|
||||
first_pooled = pooled
|
||||
|
||||
@ -55,8 +57,8 @@ class ClipTokenWeightEncoder:
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
return out[-1:].cpu(), first_pooled
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled
|
||||
return out[-1:].to(model_management.intermediate_device()), first_pooled
|
||||
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
@ -66,33 +68,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"hidden"
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
|
||||
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=clip_model.CLIPTextModel,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.num_layers = 12
|
||||
if textmodel_path is not None:
|
||||
self.transformer = model_class.from_pretrained(textmodel_path)
|
||||
else:
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
if not os.path.exists(textmodel_json_config):
|
||||
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
|
||||
config = config_class.from_json_file(textmodel_json_config)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
with ops.use_comfy_ops(device, dtype):
|
||||
with modeling_utils.no_init_weights():
|
||||
self.transformer = model_class(config)
|
||||
|
||||
self.inner_name = inner_name
|
||||
if dtype is not None:
|
||||
self.transformer.to(dtype)
|
||||
inner_model = getattr(self.transformer, self.inner_name)
|
||||
if hasattr(inner_model, "embeddings"):
|
||||
inner_model.embeddings.to(torch.float32)
|
||||
else:
|
||||
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
if not os.path.exists(textmodel_json_config):
|
||||
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
|
||||
|
||||
with open(textmodel_json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.transformer = model_class(config, dtype, device, ops.manual_cast)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
@ -107,7 +97,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) <= self.num_layers
|
||||
assert abs(layer_idx) < self.num_layers
|
||||
self.clip_layer(layer_idx)
|
||||
self.layer_default = (self.layer, self.layer_idx)
|
||||
|
||||
@ -118,7 +108,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
param.requires_grad = False
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
if abs(layer_idx) >= self.num_layers:
|
||||
if abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
else:
|
||||
self.layer = "hidden"
|
||||
@ -173,41 +163,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
|
||||
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
|
||||
for x in range(attention_mask.shape[0]):
|
||||
for y in range(attention_mask.shape[1]):
|
||||
attention_mask[x, y] = 1
|
||||
if tokens[x, y] == max_token:
|
||||
break
|
||||
|
||||
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
|
||||
if self.layer == "last":
|
||||
z = outputs[0]
|
||||
else:
|
||||
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
|
||||
z = outputs[1]
|
||||
|
||||
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
|
||||
for x in range(attention_mask.shape[0]):
|
||||
for y in range(attention_mask.shape[1]):
|
||||
attention_mask[x, y] = 1
|
||||
if tokens[x, y] == max_token:
|
||||
break
|
||||
if outputs[2] is not None:
|
||||
pooled_output = outputs[2].float()
|
||||
else:
|
||||
pooled_output = None
|
||||
|
||||
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
|
||||
self.transformer.set_input_embeddings(backup_embeds)
|
||||
|
||||
if self.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
elif self.layer == "pooled":
|
||||
z = outputs.pooler_output[:, None, :]
|
||||
else:
|
||||
z = outputs.hidden_states[self.layer_idx]
|
||||
if self.layer_norm_hidden_state:
|
||||
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
|
||||
|
||||
if hasattr(outputs, "pooler_output"):
|
||||
pooled_output = outputs.pooler_output.float()
|
||||
else:
|
||||
pooled_output = None
|
||||
|
||||
if self.text_projection is not None and pooled_output is not None:
|
||||
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
|
||||
if self.text_projection is not None and pooled_output is not None:
|
||||
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
|
||||
return z.float(), pooled_output
|
||||
|
||||
def encode(self, tokens):
|
||||
|
||||
@ -4,15 +4,15 @@ from . import sd1_clip
|
||||
import os
|
||||
|
||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=23
|
||||
layer_idx=-2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
||||
if not os.path.exists(textmodel_json_config):
|
||||
textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
|
||||
@ -3,13 +3,13 @@ import torch
|
||||
import os
|
||||
|
||||
class SDXLClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
layer_idx=-2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||
|
||||
def load_sd(self, sd):
|
||||
@ -37,7 +37,7 @@ class SDXLTokenizer:
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
|
||||
@ -217,6 +217,16 @@ class SSD1B(SDXL):
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
class Segmind_Vega(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 1, 1, 2, 2],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
class SVD_img2vid(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -242,5 +252,59 @@ class SVD_img2vid(supported_models_base.BASE):
|
||||
def clip_target(self):
|
||||
return None
|
||||
|
||||
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
|
||||
class Stable_Zero123(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": False,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
"in_channels": 8,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"num_heads": 8,
|
||||
"num_head_channels": -1,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "cond_stage_model.model.visual."
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
|
||||
return out
|
||||
|
||||
def clip_target(self):
|
||||
return None
|
||||
|
||||
class SD_X4Upscaler(SD20):
|
||||
unet_config = {
|
||||
"context_dim": 1024,
|
||||
"model_channels": 256,
|
||||
'in_channels': 7,
|
||||
"use_linear_in_transformer": True,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"disable_self_attentions": [True, True, True, False],
|
||||
"num_classes": 1000,
|
||||
"num_heads": 8,
|
||||
"num_head_channels": -1,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SD_X4
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start": 0.0001,
|
||||
"linear_end": 0.02,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SD_X4Upscaler(self, device=device)
|
||||
return out
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -22,6 +22,8 @@ class BASE:
|
||||
sampling_settings = {}
|
||||
latent_format = latent_formats.LatentFormat
|
||||
|
||||
manual_cast_dtype = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config):
|
||||
for k in s.unet_config:
|
||||
@ -71,3 +73,5 @@ class BASE:
|
||||
replace_prefix = {"": "first_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def set_manual_cast(self, manual_cast_dtype):
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
@ -7,9 +7,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .. import utils
|
||||
from .. import ops
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||
return ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||
|
||||
class Clamp(nn.Module):
|
||||
def forward(self, x):
|
||||
@ -19,7 +20,7 @@ class Block(nn.Module):
|
||||
def __init__(self, n_in, n_out):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
||||
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.skip = ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.fuse = nn.ReLU()
|
||||
def forward(self, x):
|
||||
return self.fuse(self.conv(x) + self.skip(x))
|
||||
|
||||
@ -378,7 +378,7 @@ def lanczos(samples, width, height):
|
||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
||||
result = torch.stack(images)
|
||||
return result
|
||||
return result.to(samples.device, samples.dtype)
|
||||
|
||||
def common_upscale(samples, width, height, upscale_method, crop):
|
||||
if crop == "center":
|
||||
@ -407,17 +407,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
|
||||
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
|
||||
for b in range(samples.shape[0]):
|
||||
s = samples[b:b+1]
|
||||
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
|
||||
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
|
||||
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
||||
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
||||
for y in range(0, s.shape[2], tile_y - overlap):
|
||||
for x in range(0, s.shape[3], tile_x - overlap):
|
||||
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
|
||||
|
||||
ps = function(s_in).cpu()
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
feather = round(overlap * upscale_amount)
|
||||
for t in range(feather):
|
||||
|
||||
@ -291,7 +291,7 @@ class Canny:
|
||||
|
||||
def detect_edge(self, image, low_threshold, high_threshold):
|
||||
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
|
||||
img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
|
||||
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
|
||||
return (img_out,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import comfy.samplers
|
||||
import comfy.sample
|
||||
from comfy import samplers
|
||||
from comfy import sample
|
||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||
from comfy.cmd import latent_preview
|
||||
import torch
|
||||
import comfy.utils
|
||||
from comfy import utils
|
||||
|
||||
|
||||
class BasicScheduler:
|
||||
@ -11,8 +11,9 @@ class BasicScheduler:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
|
||||
"scheduler": (samplers.SCHEDULER_NAMES, ),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
@ -20,8 +21,15 @@ class BasicScheduler:
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, model, scheduler, steps):
|
||||
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu()
|
||||
def get_sigmas(self, model, scheduler, steps, denoise):
|
||||
total_steps = steps
|
||||
if denoise < 1.0:
|
||||
total_steps = int(steps/denoise)
|
||||
|
||||
inner_model = model.patch_model(patch_weights=False)
|
||||
sigmas = samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
|
||||
model.unpatch_model()
|
||||
sigmas = sigmas[-(steps + 1):]
|
||||
return (sigmas, )
|
||||
|
||||
|
||||
@ -87,6 +95,7 @@ class SDTurboScheduler:
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
|
||||
"denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SIGMAS",)
|
||||
@ -94,9 +103,12 @@ class SDTurboScheduler:
|
||||
|
||||
FUNCTION = "get_sigmas"
|
||||
|
||||
def get_sigmas(self, model, steps):
|
||||
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
|
||||
sigmas = model.model.model_sampling.sigma(timesteps)
|
||||
def get_sigmas(self, model, steps, denoise):
|
||||
start_step = 10 - int(10 * denoise)
|
||||
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
|
||||
inner_model = model.patch_model(patch_weights=False)
|
||||
sigmas = inner_model.model_sampling.sigma(timesteps)
|
||||
model.unpatch_model()
|
||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||
return (sigmas, )
|
||||
|
||||
@ -159,7 +171,7 @@ class KSamplerSelect:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
|
||||
{"sampler_name": (samplers.SAMPLER_NAMES, ),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
@ -168,7 +180,7 @@ class KSamplerSelect:
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, sampler_name):
|
||||
sampler = comfy.samplers.sampler_object(sampler_name)
|
||||
sampler = samplers.sampler_object(sampler_name)
|
||||
return (sampler, )
|
||||
|
||||
class SamplerDPMPP_2M_SDE:
|
||||
@ -191,7 +203,7 @@ class SamplerDPMPP_2M_SDE:
|
||||
sampler_name = "dpmpp_2m_sde"
|
||||
else:
|
||||
sampler_name = "dpmpp_2m_sde_gpu"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
|
||||
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
|
||||
return (sampler, )
|
||||
|
||||
|
||||
@ -215,7 +227,7 @@ class SamplerDPMPP_SDE:
|
||||
sampler_name = "dpmpp_sde"
|
||||
else:
|
||||
sampler_name = "dpmpp_sde_gpu"
|
||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
||||
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
||||
return (sampler, )
|
||||
|
||||
class SamplerCustom:
|
||||
@ -248,7 +260,7 @@ class SamplerCustom:
|
||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||
else:
|
||||
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
||||
noise = comfy.sample.prepare_noise(latent_image, noise_seed, batch_inds)
|
||||
noise = sample.prepare_noise(latent_image, noise_seed, batch_inds)
|
||||
|
||||
noise_mask = None
|
||||
if "noise_mask" in latent:
|
||||
@ -257,8 +269,8 @@ class SamplerCustom:
|
||||
x0_output = {}
|
||||
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
|
||||
|
||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
|
||||
disable_pbar = not utils.PROGRESS_BAR_ENABLED
|
||||
samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
|
||||
|
||||
out = latent.copy()
|
||||
out["samples"] = samples
|
||||
|
||||
@ -2,9 +2,10 @@
|
||||
|
||||
import math
|
||||
from einops import rearrange
|
||||
import random
|
||||
# Use torch rng for consistency across generations
|
||||
from torch import randint
|
||||
|
||||
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
|
||||
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||
min_value = min(min_value, value)
|
||||
|
||||
# All big divisors of value (inclusive)
|
||||
@ -12,8 +13,10 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter
|
||||
|
||||
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
|
||||
|
||||
random.seed(counter)
|
||||
idx = random.randint(0, len(ns) - 1)
|
||||
if len(ns) - 1 > 0:
|
||||
idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
|
||||
else:
|
||||
idx = 0
|
||||
|
||||
return ns[idx]
|
||||
|
||||
@ -42,7 +45,6 @@ class HyperTile:
|
||||
|
||||
latent_tile_size = max(32, tile_size) // 8
|
||||
self.temp = None
|
||||
self.counter = 1
|
||||
|
||||
def hypertile_in(q, k, v, extra_options):
|
||||
if q.shape[-1] in apply_to:
|
||||
@ -53,10 +55,8 @@ class HyperTile:
|
||||
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
||||
|
||||
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
|
||||
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
|
||||
self.counter += 1
|
||||
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
|
||||
self.counter += 1
|
||||
nh = random_divisor(h, latent_tile_size * factor, swap_size)
|
||||
nw = random_divisor(w, latent_tile_size * factor, swap_size)
|
||||
|
||||
if nh * nw > 1:
|
||||
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||
|
||||
@ -73,7 +73,7 @@ class SaveAnimatedWEBP:
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "image/animation"
|
||||
|
||||
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
|
||||
method = self.methods.get(method)
|
||||
@ -135,7 +135,7 @@ class SaveAnimatedPNG:
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "image/animation"
|
||||
|
||||
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
|
||||
@ -3,9 +3,7 @@ import torch
|
||||
|
||||
def reshape_latent_to(target_shape, latent):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
latent.movedim(1, -1)
|
||||
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
|
||||
latent.movedim(-1, 1)
|
||||
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
|
||||
|
||||
|
||||
@ -102,9 +100,32 @@ class LatentInterpolate:
|
||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||
return (samples_out,)
|
||||
|
||||
class LatentBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "batch"
|
||||
|
||||
CATEGORY = "latent/batch"
|
||||
|
||||
def batch(self, samples1, samples2):
|
||||
samples_out = samples1.copy()
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
|
||||
if s1.shape[1:] != s2.shape[1:]:
|
||||
s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
|
||||
s = torch.cat((s1, s2), dim=0)
|
||||
samples_out["samples"] = s
|
||||
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
||||
return (samples_out,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentAdd": LatentAdd,
|
||||
"LatentSubtract": LatentSubtract,
|
||||
"LatentMultiply": LatentMultiply,
|
||||
"LatentInterpolate": LatentInterpolate,
|
||||
"LatentBatch": LatentBatch,
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ from comfy.nodes.common import MAX_RESOLUTION
|
||||
|
||||
|
||||
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
||||
source = source.to(destination.device)
|
||||
if resize_source:
|
||||
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
|
||||
|
||||
@ -21,7 +22,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
||||
if mask is None:
|
||||
mask = torch.ones_like(source)
|
||||
else:
|
||||
mask = mask.clone()
|
||||
mask = mask.to(destination.device, copy=True)
|
||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
|
||||
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
|
||||
|
||||
|
||||
@ -16,41 +16,19 @@ class LCM(comfy.model_sampling.EPS):
|
||||
|
||||
return c_out * x0 + c_skip * model_input
|
||||
|
||||
class ModelSamplingDiscreteDistilled(torch.nn.Module):
|
||||
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
|
||||
original_timesteps = 50
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sigma_data = 1.0
|
||||
timesteps = 1000
|
||||
beta_start = 0.00085
|
||||
beta_end = 0.012
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__(model_config)
|
||||
|
||||
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
self.skip_steps = self.num_timesteps // self.original_timesteps
|
||||
|
||||
self.skip_steps = timesteps // self.original_timesteps
|
||||
|
||||
|
||||
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
|
||||
sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
|
||||
for x in range(self.original_timesteps):
|
||||
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
||||
sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
|
||||
|
||||
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
self.set_sigmas(sigmas_valid)
|
||||
|
||||
def timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
@ -65,14 +43,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp().to(timestep.device)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 999999999.9
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
percent = 1.0 - percent
|
||||
return self.sigma(torch.tensor(percent * 999.0)).item()
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||
@ -121,7 +91,7 @@ class ModelSamplingDiscrete:
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced()
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
if zsnr:
|
||||
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
|
||||
|
||||
@ -153,7 +123,7 @@ class ModelSamplingContinuousEDM:
|
||||
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced()
|
||||
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||
model_sampling.set_sigma_range(sigma_min, sigma_max)
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
53
comfy_extras/nodes/nodes_perpneg.py
Normal file
53
comfy_extras/nodes/nodes_perpneg.py
Normal file
@ -0,0 +1,53 @@
|
||||
import torch
|
||||
from comfy import sample
|
||||
from comfy import samplers
|
||||
|
||||
|
||||
class PerpNeg:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model": ("MODEL", ),
|
||||
"empty_conditioning": ("CONDITIONING", ),
|
||||
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def patch(self, model, empty_conditioning, neg_scale):
|
||||
m = model.clone()
|
||||
nocond = sample.convert_cond(empty_conditioning)
|
||||
|
||||
def cfg_function(args):
|
||||
model = args["model"]
|
||||
noise_pred_pos = args["cond_denoised"]
|
||||
noise_pred_neg = args["uncond_denoised"]
|
||||
cond_scale = args["cond_scale"]
|
||||
x = args["input"]
|
||||
sigma = args["sigma"]
|
||||
model_options = args["model_options"]
|
||||
nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
|
||||
|
||||
(noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
|
||||
|
||||
pos = noise_pred_pos - noise_pred_nocond
|
||||
neg = noise_pred_neg - noise_pred_nocond
|
||||
perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg
|
||||
perp_neg = perp * neg_scale
|
||||
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
|
||||
cfg_result = x - cfg_result
|
||||
return cfg_result
|
||||
|
||||
m.set_model_sampler_cfg_function(cfg_function)
|
||||
|
||||
return (m, )
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PerpNeg": PerpNeg,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PerpNeg": "Perp-Neg",
|
||||
}
|
||||
@ -226,7 +226,7 @@ class Sharpen:
|
||||
batch_size, height, width, channels = image.shape
|
||||
|
||||
kernel_size = sharpen_radius * 2 + 1
|
||||
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
|
||||
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
|
||||
center = kernel_size // 2
|
||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
||||
|
||||
@ -99,10 +99,40 @@ class LatentRebatch:
|
||||
|
||||
return (output_list,)
|
||||
|
||||
class ImageRebatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "images": ("IMAGE",),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
INPUT_IS_LIST = True
|
||||
OUTPUT_IS_LIST = (True, )
|
||||
|
||||
FUNCTION = "rebatch"
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def rebatch(self, images, batch_size):
|
||||
batch_size = batch_size[0]
|
||||
|
||||
output_list = []
|
||||
all_images = []
|
||||
for img in images:
|
||||
for i in range(img.shape[0]):
|
||||
all_images.append(img[i:i+1])
|
||||
|
||||
for i in range(0, len(all_images), batch_size):
|
||||
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
|
||||
|
||||
return (output_list,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"RebatchLatents": LatentRebatch,
|
||||
"RebatchImages": ImageRebatch,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"RebatchLatents": "Rebatch Latents",
|
||||
}
|
||||
"RebatchImages": "Rebatch Images",
|
||||
}
|
||||
|
||||
168
comfy_extras/nodes/nodes_sag.py
Normal file
168
comfy_extras/nodes/nodes_sag.py
Normal file
@ -0,0 +1,168 @@
|
||||
import torch
|
||||
from torch import einsum
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import os
|
||||
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
|
||||
from comfy import samplers
|
||||
|
||||
# from comfy/ldm/modules/attention.py
|
||||
# but modified to return attention scores as well as output
|
||||
def attention_basic_with_sim(q, k, v, heads, mask=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
scale = dim_head ** -0.5
|
||||
|
||||
h = heads
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, -1, heads, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * heads, -1, dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||
|
||||
del q, k
|
||||
|
||||
if mask is not None:
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return (out, sim)
|
||||
|
||||
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||
# reshape and GAP the attention map
|
||||
_, hw1, hw2 = attn.shape
|
||||
b, _, lh, lw = x0.shape
|
||||
attn = attn.reshape(b, -1, hw1, hw2)
|
||||
# Global Average Pool
|
||||
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||
ratio = math.ceil(math.sqrt(lh * lw / hw1))
|
||||
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
|
||||
|
||||
# Reshape
|
||||
mask = (
|
||||
mask.reshape(b, *mid_shape)
|
||||
.unsqueeze(1)
|
||||
.type(attn.dtype)
|
||||
)
|
||||
# Upsample
|
||||
mask = F.interpolate(mask, (lh, lw))
|
||||
|
||||
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
||||
blurred = blurred * mask + x0 * (1 - mask)
|
||||
return blurred
|
||||
|
||||
def gaussian_blur_2d(img, kernel_size, sigma):
|
||||
ksize_half = (kernel_size - 1) * 0.5
|
||||
|
||||
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
||||
|
||||
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||
|
||||
x_kernel = pdf / pdf.sum()
|
||||
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
|
||||
|
||||
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
|
||||
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
|
||||
|
||||
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||
|
||||
img = F.pad(img, padding, mode="reflect")
|
||||
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
|
||||
return img
|
||||
|
||||
class SelfAttentionGuidance:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
|
||||
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def patch(self, model, scale, blur_sigma):
|
||||
m = model.clone()
|
||||
|
||||
attn_scores = None
|
||||
|
||||
# TODO: make this work properly with chunked batches
|
||||
# currently, we can only save the attn from one UNet call
|
||||
def attn_and_record(q, k, v, extra_options):
|
||||
nonlocal attn_scores
|
||||
# if uncond, save the attention scores
|
||||
heads = extra_options["n_heads"]
|
||||
cond_or_uncond = extra_options["cond_or_uncond"]
|
||||
b = q.shape[0] // len(cond_or_uncond)
|
||||
if 1 in cond_or_uncond:
|
||||
uncond_index = cond_or_uncond.index(1)
|
||||
# do the entire attention operation, but save the attention scores to attn_scores
|
||||
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
|
||||
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
|
||||
n_slices = heads * b
|
||||
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
|
||||
return out
|
||||
else:
|
||||
return optimized_attention(q, k, v, heads=heads)
|
||||
|
||||
def post_cfg_function(args):
|
||||
nonlocal attn_scores
|
||||
uncond_attn = attn_scores
|
||||
|
||||
sag_scale = scale
|
||||
sag_sigma = blur_sigma
|
||||
sag_threshold = 1.0
|
||||
model = args["model"]
|
||||
uncond_pred = args["uncond_denoised"]
|
||||
uncond = args["uncond"]
|
||||
cfg_result = args["denoised"]
|
||||
sigma = args["sigma"]
|
||||
model_options = args["model_options"]
|
||||
x = args["input"]
|
||||
|
||||
# create the adversarially blurred image
|
||||
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
||||
degraded_noised = degraded + x - uncond_pred
|
||||
# call into the UNet
|
||||
(sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
|
||||
return cfg_result + (degraded - sag) * sag_scale
|
||||
|
||||
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
|
||||
|
||||
# from diffusers:
|
||||
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
|
||||
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
|
||||
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SelfAttentionGuidance": SelfAttentionGuidance,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"SelfAttentionGuidance": "Self-Attention Guidance",
|
||||
}
|
||||
46
comfy_extras/nodes/nodes_sdupscale.py
Normal file
46
comfy_extras/nodes/nodes_sdupscale.py
Normal file
@ -0,0 +1,46 @@
|
||||
import torch
|
||||
from comfy import utils
|
||||
|
||||
class SD_4XUpscale_Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "images": ("IMAGE",),
|
||||
"positive": ("CONDITIONING",),
|
||||
"negative": ("CONDITIONING",),
|
||||
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/upscale_diffusion"
|
||||
|
||||
def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
|
||||
width = max(1, round(images.shape[-2] * scale_ratio))
|
||||
height = max(1, round(images.shape[-3] * scale_ratio))
|
||||
|
||||
pixels = utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
|
||||
|
||||
out_cp = []
|
||||
out_cn = []
|
||||
|
||||
for t in positive:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['concat_image'] = pixels
|
||||
n[1]['noise_augmentation'] = noise_augmentation
|
||||
out_cp.append(n)
|
||||
|
||||
for t in negative:
|
||||
n = [t[0], t[1].copy()]
|
||||
n[1]['concat_image'] = pixels
|
||||
n[1]['noise_augmentation'] = noise_augmentation
|
||||
out_cn.append(n)
|
||||
|
||||
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
|
||||
return (out_cp, out_cn, {"samples":latent})
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
|
||||
}
|
||||
61
comfy_extras/nodes/nodes_stable3d.py
Normal file
61
comfy_extras/nodes/nodes_stable3d.py
Normal file
@ -0,0 +1,61 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
from comfy.nodes.common import MAX_RESOLUTION
|
||||
from comfy import utils
|
||||
|
||||
|
||||
def camera_embeddings(elevation, azimuth):
|
||||
elevation = torch.as_tensor([elevation])
|
||||
azimuth = torch.as_tensor([azimuth])
|
||||
embeddings = torch.stack(
|
||||
[
|
||||
torch.deg2rad(
|
||||
(90 - elevation) - (90)
|
||||
), # Zero123 polar is 90-elevation
|
||||
torch.sin(torch.deg2rad(azimuth)),
|
||||
torch.cos(torch.deg2rad(azimuth)),
|
||||
torch.deg2rad(
|
||||
90 - torch.full_like(elevation, 0)
|
||||
),
|
||||
], dim=-1).unsqueeze(1)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class StableZero123_Conditioning:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_vision": ("CLIP_VISION",),
|
||||
"init_image": ("IMAGE",),
|
||||
"vae": ("VAE",),
|
||||
"width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||
RETURN_NAMES = ("positive", "negative", "latent")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/3d_models"
|
||||
|
||||
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
|
||||
output = clip_vision.encode_image(init_image)
|
||||
pooled = output.image_embeds.unsqueeze(0)
|
||||
pixels = utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
|
||||
encode_pixels = pixels[:,:,:,:3]
|
||||
t = vae.encode(encode_pixels)
|
||||
cam_embeds = camera_embeddings(elevation, azimuth)
|
||||
cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
|
||||
|
||||
positive = [[cond, {"concat_latent_image": t}]]
|
||||
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
|
||||
return (positive, negative, {"samples":latent})
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"StableZero123_Conditioning": StableZero123_Conditioning,
|
||||
}
|
||||
@ -3,6 +3,7 @@ torchaudio
|
||||
torchvision
|
||||
torchdiffeq>=0.2.3
|
||||
torchsde>=0.2.6
|
||||
torchvision
|
||||
einops>=0.6.0
|
||||
open-clip-torch>=2.16.0
|
||||
transformers>=4.29.1
|
||||
|
||||
19
setup.py
19
setup.py
@ -28,18 +28,18 @@ version = '0.0.1'
|
||||
"""
|
||||
The package index to the torch built with AMD ROCm.
|
||||
"""
|
||||
amd_torch_index = "https://download.pytorch.org/whl/rocm5.6"
|
||||
amd_torch_index = ("https://download.pytorch.org/whl/rocm5.6", "https://download.pytorch.org/whl/nightly/rocm5.7")
|
||||
|
||||
"""
|
||||
The package index to torch built with CUDA.
|
||||
Observe the CUDA version is in this URL.
|
||||
"""
|
||||
nvidia_torch_index = "https://download.pytorch.org/whl/cu121"
|
||||
nvidia_torch_index = ("https://download.pytorch.org/whl/cu121", "https://download.pytorch.org/whl/nightly/cu121")
|
||||
|
||||
"""
|
||||
The package index to torch built against CPU features.
|
||||
"""
|
||||
cpu_torch_index = "https://download.pytorch.org/whl/cpu"
|
||||
cpu_torch_index = ("https://download.pytorch.org/whl/cpu", "https://download.pytorch.org/whl/nightly/cpu")
|
||||
|
||||
# xformers not required for new torch
|
||||
|
||||
@ -102,11 +102,11 @@ def _is_linux_arm64():
|
||||
|
||||
def dependencies() -> List[str]:
|
||||
_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines()
|
||||
# todo: also add all plugin dependencies
|
||||
_alternative_indices = [amd_torch_index, nvidia_torch_index]
|
||||
session = PipSession()
|
||||
|
||||
index_urls = ['https://pypi.org/simple']
|
||||
# (stable, nightly) tuple
|
||||
index_urls = [('https://pypi.org/simple', 'https://pypi.org/simple')]
|
||||
# prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device
|
||||
if _is_nvidia():
|
||||
index_urls += [nvidia_torch_index]
|
||||
@ -118,6 +118,13 @@ def dependencies() -> List[str]:
|
||||
if len(index_urls) == 1:
|
||||
return _dependencies
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
# use the nightlies
|
||||
index_urls = [nightly for (_, nightly) in index_urls]
|
||||
_alternative_indices = [nightly for (_, nightly) in _alternative_indices]
|
||||
else:
|
||||
index_urls = [stable for (stable, _) in index_urls]
|
||||
_alternative_indices = [stable for (stable, _) in _alternative_indices]
|
||||
try:
|
||||
# pip 23
|
||||
finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls, no_index=False)),
|
||||
@ -149,7 +156,7 @@ setup(
|
||||
description="",
|
||||
author="",
|
||||
version=version,
|
||||
python_requires=">=3.9,<3.12",
|
||||
python_requires=">=3.9,<3.13",
|
||||
# todo: figure out how to include the web directory to eventually let main live inside the package
|
||||
# todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins
|
||||
packages=find_packages(exclude=[] if is_editable else ['custom_nodes']),
|
||||
|
||||
9
tests-ui/afterSetup.js
Normal file
9
tests-ui/afterSetup.js
Normal file
@ -0,0 +1,9 @@
|
||||
const { start } = require("./utils");
|
||||
const lg = require("./utils/litegraph");
|
||||
|
||||
// Load things once per test file before to ensure its all warmed up for the tests
|
||||
beforeAll(async () => {
|
||||
lg.setup(global);
|
||||
await start({ resetEnv: true });
|
||||
lg.teardown(global);
|
||||
});
|
||||
@ -2,8 +2,10 @@
|
||||
const config = {
|
||||
testEnvironment: "jsdom",
|
||||
setupFiles: ["./globalSetup.js"],
|
||||
setupFilesAfterEnv: ["./afterSetup.js"],
|
||||
clearMocks: true,
|
||||
resetModules: true,
|
||||
testTimeout: 10000
|
||||
};
|
||||
|
||||
module.exports = config;
|
||||
|
||||
@ -52,7 +52,7 @@ describe("extensions", () => {
|
||||
const nodeNames = Object.keys(defs);
|
||||
const nodeCount = nodeNames.length;
|
||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
|
||||
for (let i = 0; i < nodeCount; i++) {
|
||||
for (let i = 0; i < 10; i++) {
|
||||
// It should be send the JS class and the original JSON definition
|
||||
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
|
||||
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
|
||||
@ -133,7 +133,7 @@ describe("extensions", () => {
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
|
||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
}, 15000);
|
||||
|
||||
it("allows custom nodeDefs and widgets to be registered", async () => {
|
||||
const widgetMock = jest.fn((node, inputName, inputData, app) => {
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
// @ts-check
|
||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
||||
|
||||
const { start, createDefaultWorkflow } = require("../utils");
|
||||
const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils");
|
||||
const lg = require("../utils/litegraph");
|
||||
|
||||
describe("group node", () => {
|
||||
@ -273,7 +273,7 @@ describe("group node", () => {
|
||||
|
||||
let reroutes = [];
|
||||
let prevNode = nodes.ckpt;
|
||||
for(let i = 0; i < 5; i++) {
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const reroute = ez.Reroute();
|
||||
prevNode.outputs[0].connectTo(reroute.inputs[0]);
|
||||
prevNode = reroute;
|
||||
@ -283,7 +283,7 @@ describe("group node", () => {
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
|
||||
expect((await graph.toPrompt()).output).toEqual(getOutput());
|
||||
|
||||
|
||||
group.menu["Convert to nodes"].call();
|
||||
expect((await graph.toPrompt()).output).toEqual(getOutput());
|
||||
});
|
||||
@ -383,6 +383,43 @@ describe("group node", () => {
|
||||
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
|
||||
);
|
||||
});
|
||||
test("groups can connect to each other via internal reroutes", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
|
||||
const latent = ez.EmptyLatentImage();
|
||||
const vae = ez.VAELoader();
|
||||
const latentReroute = ez.Reroute();
|
||||
const vaeReroute = ez.Reroute();
|
||||
|
||||
latent.outputs[0].connectTo(latentReroute.inputs[0]);
|
||||
vae.outputs[0].connectTo(vaeReroute.inputs[0]);
|
||||
|
||||
const group1 = await convertToGroup(app, graph, "test", [latentReroute, vaeReroute]);
|
||||
group1.menu.Clone.call();
|
||||
expect(app.graph._nodes).toHaveLength(4);
|
||||
const group2 = graph.find(app.graph._nodes[3]);
|
||||
expect(group2.node.type).toEqual("workflow/test");
|
||||
expect(group2.id).not.toEqual(group1.id);
|
||||
|
||||
group1.outputs.VAE.connectTo(group2.inputs.VAE);
|
||||
group1.outputs.LATENT.connectTo(group2.inputs.LATENT);
|
||||
|
||||
const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE);
|
||||
const preview = ez.PreviewImage(decode.outputs[0]);
|
||||
|
||||
const output = {
|
||||
[latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" },
|
||||
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
|
||||
[decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" },
|
||||
[preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" },
|
||||
};
|
||||
expect((await graph.toPrompt()).output).toEqual(output);
|
||||
|
||||
// Ensure missing connections dont cause errors
|
||||
group2.inputs.VAE.disconnect();
|
||||
delete output[decode.id].inputs.vae;
|
||||
expect((await graph.toPrompt()).output).toEqual(output);
|
||||
});
|
||||
test("displays generated image on group node", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
@ -642,6 +679,55 @@ describe("group node", () => {
|
||||
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
|
||||
});
|
||||
});
|
||||
test("correctly handles widget inputs", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0];
|
||||
|
||||
const image = ez.LoadImage();
|
||||
const scale1 = ez.ImageScaleBy(image.outputs[0]);
|
||||
const scale2 = ez.ImageScaleBy(image.outputs[0]);
|
||||
const preview1 = ez.PreviewImage(scale1.outputs[0]);
|
||||
const preview2 = ez.PreviewImage(scale2.outputs[0]);
|
||||
scale1.widgets.upscale_method.value = upscaleMethods[1];
|
||||
scale1.widgets.upscale_method.convertToInput();
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [scale1, scale2]);
|
||||
expect(group.inputs.length).toBe(3);
|
||||
expect(group.inputs[0].input.type).toBe("IMAGE");
|
||||
expect(group.inputs[1].input.type).toBe("IMAGE");
|
||||
expect(group.inputs[2].input.type).toBe("COMBO");
|
||||
|
||||
// Ensure links are maintained
|
||||
expect(group.inputs[0].connection?.originNode?.id).toBe(image.id);
|
||||
expect(group.inputs[1].connection?.originNode?.id).toBe(image.id);
|
||||
expect(group.inputs[2].connection).toBeFalsy();
|
||||
|
||||
// Ensure primitive gets correct type
|
||||
const primitive = ez.PrimitiveNode();
|
||||
primitive.outputs[0].connectTo(group.inputs[2]);
|
||||
expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods);
|
||||
expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied
|
||||
primitive.widgets.value.value = upscaleMethods[1];
|
||||
|
||||
await checkBeforeAndAfterReload(graph, async (r) => {
|
||||
const scale1id = r ? `${group.id}:0` : scale1.id;
|
||||
const scale2id = r ? `${group.id}:1` : scale2.id;
|
||||
// Ensure widget value is applied to prompt
|
||||
expect((await graph.toPrompt()).output).toStrictEqual({
|
||||
[image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
|
||||
[scale1id]: {
|
||||
inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
|
||||
class_type: "ImageScaleBy",
|
||||
},
|
||||
[scale2id]: {
|
||||
inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
|
||||
class_type: "ImageScaleBy",
|
||||
},
|
||||
[preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" },
|
||||
[preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" },
|
||||
});
|
||||
});
|
||||
});
|
||||
test("adds widgets in node execution order", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const scale = ez.LatentUpscale();
|
||||
@ -815,4 +901,105 @@ describe("group node", () => {
|
||||
expect(p2.widgets.control_after_generate.value).toBe("randomize");
|
||||
expect(p2.widgets.control_filter_list.value).toBe("/.+/");
|
||||
});
|
||||
test("internal reroutes work with converted inputs and merge options", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const vae = ez.VAELoader();
|
||||
const latent = ez.EmptyLatentImage();
|
||||
const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE);
|
||||
const scale = ez.ImageScale(decode.outputs.IMAGE);
|
||||
ez.PreviewImage(scale.outputs.IMAGE);
|
||||
|
||||
const r1 = ez.Reroute();
|
||||
const r2 = ez.Reroute();
|
||||
|
||||
latent.widgets.width.value = 64;
|
||||
latent.widgets.height.value = 128;
|
||||
|
||||
latent.widgets.width.convertToInput();
|
||||
latent.widgets.height.convertToInput();
|
||||
latent.widgets.batch_size.convertToInput();
|
||||
|
||||
scale.widgets.width.convertToInput();
|
||||
scale.widgets.height.convertToInput();
|
||||
|
||||
r1.inputs[0].input.label = "hbw";
|
||||
r1.outputs[0].connectTo(latent.inputs.height);
|
||||
r1.outputs[0].connectTo(latent.inputs.batch_size);
|
||||
r1.outputs[0].connectTo(scale.inputs.width);
|
||||
|
||||
r2.inputs[0].input.label = "wh";
|
||||
r2.outputs[0].connectTo(latent.inputs.width);
|
||||
r2.outputs[0].connectTo(scale.inputs.height);
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]);
|
||||
|
||||
expect(group.inputs[0].input.type).toBe("VAE");
|
||||
expect(group.inputs[1].input.type).toBe("INT");
|
||||
expect(group.inputs[2].input.type).toBe("INT");
|
||||
|
||||
const p1 = ez.PrimitiveNode();
|
||||
const p2 = ez.PrimitiveNode();
|
||||
p1.outputs[0].connectTo(group.inputs[1]);
|
||||
p2.outputs[0].connectTo(group.inputs[2]);
|
||||
|
||||
expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min
|
||||
expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max
|
||||
expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
|
||||
|
||||
expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
|
||||
expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
|
||||
expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
|
||||
|
||||
expect(p1.widgets.value.value).toBe(128);
|
||||
expect(p2.widgets.value.value).toBe(64);
|
||||
|
||||
p1.widgets.value.value = 16;
|
||||
p2.widgets.value.value = 32;
|
||||
|
||||
await checkBeforeAndAfterReload(graph, async (r) => {
|
||||
const id = (v) => (r ? `${group.id}:` : "") + v;
|
||||
expect((await graph.toPrompt()).output).toStrictEqual({
|
||||
1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
|
||||
[id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" },
|
||||
[id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" },
|
||||
[id(4)]: {
|
||||
inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] },
|
||||
class_type: "ImageScale",
|
||||
},
|
||||
5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" },
|
||||
});
|
||||
});
|
||||
});
|
||||
test("converted inputs with linked widgets map values correctly on creation", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const k1 = ez.KSampler();
|
||||
const k2 = ez.KSampler();
|
||||
k1.widgets.seed.convertToInput();
|
||||
k2.widgets.seed.convertToInput();
|
||||
|
||||
const rr = ez.Reroute();
|
||||
rr.outputs[0].connectTo(k1.inputs.seed);
|
||||
rr.outputs[0].connectTo(k2.inputs.seed);
|
||||
|
||||
const group = await convertToGroup(app, graph, "test", [k1, k2, rr]);
|
||||
expect(group.widgets.steps.value).toBe(20);
|
||||
expect(group.widgets.cfg.value).toBe(8);
|
||||
expect(group.widgets.scheduler.value).toBe("normal");
|
||||
expect(group.widgets["KSampler steps"].value).toBe(20);
|
||||
expect(group.widgets["KSampler cfg"].value).toBe(8);
|
||||
expect(group.widgets["KSampler scheduler"].value).toBe("normal");
|
||||
});
|
||||
test("allow multiple of the same node type to be added", async () => {
|
||||
const { ez, graph, app } = await start();
|
||||
const nodes = [...Array(10)].map(() => ez.ImageScaleBy());
|
||||
const group = await convertToGroup(app, graph, "test", nodes);
|
||||
expect(group.inputs.length).toBe(10);
|
||||
expect(group.outputs.length).toBe(10);
|
||||
expect(group.widgets.length).toBe(20);
|
||||
expect(group.widgets.map((w) => w.widget.name)).toStrictEqual(
|
||||
[...Array(10)]
|
||||
.map((_, i) => `${i > 0 ? "ImageScaleBy " : ""}${i > 1 ? i + " " : ""}`)
|
||||
.flatMap((p) => [`${p}upscale_method`, `${p}scale_by`])
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@ -1,7 +1,13 @@
|
||||
// @ts-check
|
||||
/// <reference path="../node_modules/@types/jest/index.d.ts" />
|
||||
|
||||
const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils");
|
||||
const {
|
||||
start,
|
||||
makeNodeDef,
|
||||
checkBeforeAndAfterReload,
|
||||
assertNotNullOrUndefined,
|
||||
createDefaultWorkflow,
|
||||
} = require("../utils");
|
||||
const lg = require("../utils/litegraph");
|
||||
|
||||
/**
|
||||
@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWi
|
||||
if (controlWidgetCount) {
|
||||
const controlWidget = primitive.widgets.control_after_generate;
|
||||
expect(controlWidget.widget.type).toBe("combo");
|
||||
if(widgetType === "combo") {
|
||||
if (widgetType === "combo") {
|
||||
const filterWidget = primitive.widgets.control_filter_list;
|
||||
expect(filterWidget.widget.type).toBe("string");
|
||||
}
|
||||
@ -308,8 +314,8 @@ describe("widget inputs", () => {
|
||||
const { ez } = await start({
|
||||
mockNodeDefs: {
|
||||
...makeNodeDef("TestNode1", {}, [["A", "B"]]),
|
||||
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }),
|
||||
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }),
|
||||
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
|
||||
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }),
|
||||
},
|
||||
});
|
||||
|
||||
@ -330,7 +336,7 @@ describe("widget inputs", () => {
|
||||
|
||||
const n1 = ez.TestNode1();
|
||||
n1.widgets.example.convertToInput();
|
||||
const p = ez.PrimitiveNode()
|
||||
const p = ez.PrimitiveNode();
|
||||
p.outputs[0].connectTo(n1.inputs[0]);
|
||||
|
||||
const value = p.widgets.value;
|
||||
@ -380,7 +386,7 @@ describe("widget inputs", () => {
|
||||
// Check random
|
||||
control.value = "randomize";
|
||||
filter.value = "/D/";
|
||||
for(let i = 0; i < 100; i++) {
|
||||
for (let i = 0; i < 100; i++) {
|
||||
control["afterQueued"]();
|
||||
expect(value.value === "D" || value.value === "DD").toBeTruthy();
|
||||
}
|
||||
@ -392,4 +398,160 @@ describe("widget inputs", () => {
|
||||
control["afterQueued"]();
|
||||
expect(value.value).toBe("B");
|
||||
});
|
||||
|
||||
describe("reroutes", () => {
|
||||
async function checkOutput(graph, values) {
|
||||
expect((await graph.toPrompt()).output).toStrictEqual({
|
||||
1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" },
|
||||
2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
|
||||
3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
|
||||
4: {
|
||||
inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 },
|
||||
class_type: "EmptyLatentImage",
|
||||
},
|
||||
5: {
|
||||
inputs: {
|
||||
seed: 0,
|
||||
steps: 20,
|
||||
cfg: 8,
|
||||
sampler_name: "euler",
|
||||
scheduler: values?.scheduler ?? "normal",
|
||||
denoise: 1,
|
||||
model: ["1", 0],
|
||||
positive: ["2", 0],
|
||||
negative: ["3", 0],
|
||||
latent_image: ["4", 0],
|
||||
},
|
||||
class_type: "KSampler",
|
||||
},
|
||||
6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" },
|
||||
7: {
|
||||
inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] },
|
||||
class_type: "SaveImage",
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async function waitForWidget(node) {
|
||||
// widgets are created slightly after the graph is ready
|
||||
// hard to find an exact hook to get these so just wait for them to be ready
|
||||
for (let i = 0; i < 10; i++) {
|
||||
await new Promise((r) => setTimeout(r, 10));
|
||||
if (node.widgets?.value) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
it("can connect primitive via a reroute path to a widget input", async () => {
|
||||
const { ez, graph } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
|
||||
nodes.empty.widgets.width.convertToInput();
|
||||
nodes.sampler.widgets.scheduler.convertToInput();
|
||||
nodes.save.widgets.filename_prefix.convertToInput();
|
||||
|
||||
let widthReroute = ez.Reroute();
|
||||
let schedulerReroute = ez.Reroute();
|
||||
let fileReroute = ez.Reroute();
|
||||
|
||||
let widthNext = widthReroute;
|
||||
let schedulerNext = schedulerReroute;
|
||||
let fileNext = fileReroute;
|
||||
|
||||
for (let i = 0; i < 5; i++) {
|
||||
let next = ez.Reroute();
|
||||
widthNext.outputs[0].connectTo(next.inputs[0]);
|
||||
widthNext = next;
|
||||
|
||||
next = ez.Reroute();
|
||||
schedulerNext.outputs[0].connectTo(next.inputs[0]);
|
||||
schedulerNext = next;
|
||||
|
||||
next = ez.Reroute();
|
||||
fileNext.outputs[0].connectTo(next.inputs[0]);
|
||||
fileNext = next;
|
||||
}
|
||||
|
||||
widthNext.outputs[0].connectTo(nodes.empty.inputs.width);
|
||||
schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler);
|
||||
fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix);
|
||||
|
||||
let widthPrimitive = ez.PrimitiveNode();
|
||||
let schedulerPrimitive = ez.PrimitiveNode();
|
||||
let filePrimitive = ez.PrimitiveNode();
|
||||
|
||||
widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]);
|
||||
schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]);
|
||||
filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]);
|
||||
expect(widthPrimitive.widgets.value.value).toBe(512);
|
||||
widthPrimitive.widgets.value.value = 1024;
|
||||
expect(schedulerPrimitive.widgets.value.value).toBe("normal");
|
||||
schedulerPrimitive.widgets.value.value = "simple";
|
||||
expect(filePrimitive.widgets.value.value).toBe("ComfyUI");
|
||||
filePrimitive.widgets.value.value = "ComfyTest";
|
||||
|
||||
await checkBeforeAndAfterReload(graph, async () => {
|
||||
widthPrimitive = graph.find(widthPrimitive);
|
||||
schedulerPrimitive = graph.find(schedulerPrimitive);
|
||||
filePrimitive = graph.find(filePrimitive);
|
||||
await waitForWidget(filePrimitive);
|
||||
expect(widthPrimitive.widgets.length).toBe(2);
|
||||
expect(schedulerPrimitive.widgets.length).toBe(3);
|
||||
expect(filePrimitive.widgets.length).toBe(1);
|
||||
|
||||
await checkOutput(graph, {
|
||||
width: 1024,
|
||||
scheduler: "simple",
|
||||
filename_prefix: "ComfyTest",
|
||||
});
|
||||
});
|
||||
});
|
||||
it("can connect primitive via a reroute path to multiple widget inputs", async () => {
|
||||
const { ez, graph } = await start();
|
||||
const nodes = createDefaultWorkflow(ez, graph);
|
||||
|
||||
nodes.empty.widgets.width.convertToInput();
|
||||
nodes.empty.widgets.height.convertToInput();
|
||||
nodes.empty.widgets.batch_size.convertToInput();
|
||||
|
||||
let reroute = ez.Reroute();
|
||||
let prevReroute = reroute;
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const next = ez.Reroute();
|
||||
prevReroute.outputs[0].connectTo(next.inputs[0]);
|
||||
prevReroute = next;
|
||||
}
|
||||
|
||||
const r1 = ez.Reroute(prevReroute.outputs[0]);
|
||||
const r2 = ez.Reroute(prevReroute.outputs[0]);
|
||||
const r3 = ez.Reroute(r2.outputs[0]);
|
||||
const r4 = ez.Reroute(r2.outputs[0]);
|
||||
|
||||
r1.outputs[0].connectTo(nodes.empty.inputs.width);
|
||||
r3.outputs[0].connectTo(nodes.empty.inputs.height);
|
||||
r4.outputs[0].connectTo(nodes.empty.inputs.batch_size);
|
||||
|
||||
let primitive = ez.PrimitiveNode();
|
||||
primitive.outputs[0].connectTo(reroute.inputs[0]);
|
||||
expect(primitive.widgets.value.value).toBe(1);
|
||||
primitive.widgets.value.value = 64;
|
||||
|
||||
await checkBeforeAndAfterReload(graph, async (r) => {
|
||||
primitive = graph.find(primitive);
|
||||
await waitForWidget(primitive);
|
||||
|
||||
// Ensure widget configs are merged
|
||||
expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min
|
||||
expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max
|
||||
expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
|
||||
|
||||
await checkOutput(graph, {
|
||||
width: 64,
|
||||
height: 64,
|
||||
batch_size: 64,
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -78,6 +78,14 @@ export class EzInput extends EzSlot {
|
||||
this.input = input;
|
||||
}
|
||||
|
||||
get connection() {
|
||||
const link = this.node.node.inputs?.[this.index]?.link;
|
||||
if (link == null) {
|
||||
return null;
|
||||
}
|
||||
return new EzConnection(this.node.app, this.node.app.graph.links[link]);
|
||||
}
|
||||
|
||||
disconnect() {
|
||||
this.node.node.disconnectInput(this.index);
|
||||
}
|
||||
@ -117,7 +125,7 @@ export class EzOutput extends EzSlot {
|
||||
const inp = input.input;
|
||||
const inName = inp.name || inp.label || inp.type;
|
||||
throw new Error(
|
||||
`Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${
|
||||
`Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${
|
||||
this.output.name ?? this.output.type
|
||||
}#${this.index}] failed.`
|
||||
);
|
||||
@ -179,6 +187,7 @@ export class EzWidget {
|
||||
|
||||
set value(v) {
|
||||
this.widget.value = v;
|
||||
this.widget.callback?.call?.(this.widget, v)
|
||||
}
|
||||
|
||||
get isConvertedToInput() {
|
||||
@ -319,7 +328,7 @@ export class EzGraph {
|
||||
}
|
||||
|
||||
stringify() {
|
||||
return JSON.stringify(this.app.graph.serialize(), undefined, "\t");
|
||||
return JSON.stringify(this.app.graph.serialize(), undefined);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) {
|
||||
|
||||
return { ckpt, pos, neg, empty, sampler, decode, save };
|
||||
}
|
||||
|
||||
export async function getNodeDefs() {
|
||||
const { api } = require("../../web/scripts/api");
|
||||
return api.getNodeDefs();
|
||||
}
|
||||
|
||||
export async function getNodeDef(nodeId) {
|
||||
return (await getNodeDefs())[nodeId];
|
||||
}
|
||||
@ -174,6 +174,11 @@ export class GroupNodeConfig {
|
||||
node.index = i;
|
||||
this.processNode(node, seenInputs, seenOutputs);
|
||||
}
|
||||
|
||||
for (const p of this.#convertedToProcess) {
|
||||
p();
|
||||
}
|
||||
this.#convertedToProcess = null;
|
||||
await app.registerNodeDef("workflow/" + this.name, this.nodeDef);
|
||||
}
|
||||
|
||||
@ -192,7 +197,10 @@ export class GroupNodeConfig {
|
||||
if (!this.linksFrom[sourceNodeId]) {
|
||||
this.linksFrom[sourceNodeId] = {};
|
||||
}
|
||||
this.linksFrom[sourceNodeId][sourceNodeSlot] = l;
|
||||
if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) {
|
||||
this.linksFrom[sourceNodeId][sourceNodeSlot] = [];
|
||||
}
|
||||
this.linksFrom[sourceNodeId][sourceNodeSlot].push(l);
|
||||
|
||||
if (!this.linksTo[targetNodeId]) {
|
||||
this.linksTo[targetNodeId] = {};
|
||||
@ -230,11 +238,11 @@ export class GroupNodeConfig {
|
||||
// Skip as its not linked
|
||||
if (!linksFrom) return;
|
||||
|
||||
let type = linksFrom["0"][5];
|
||||
let type = linksFrom["0"][0][5];
|
||||
if (type === "COMBO") {
|
||||
// Use the array items
|
||||
const source = node.outputs[0].widget.name;
|
||||
const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type;
|
||||
const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type;
|
||||
const fromType = globalDefs[fromTypeName];
|
||||
const input = fromType.input.required[source] ?? fromType.input.optional[source];
|
||||
type = input[0];
|
||||
@ -258,10 +266,33 @@ export class GroupNodeConfig {
|
||||
return null;
|
||||
}
|
||||
|
||||
let config = {};
|
||||
let rerouteType = "*";
|
||||
if (linksFrom) {
|
||||
const [, , id, slot] = linksFrom["0"];
|
||||
rerouteType = this.nodeData.nodes[id].inputs[slot].type;
|
||||
for (const [, , id, slot] of linksFrom["0"]) {
|
||||
const node = this.nodeData.nodes[id];
|
||||
const input = node.inputs[slot];
|
||||
if (rerouteType === "*") {
|
||||
rerouteType = input.type;
|
||||
}
|
||||
if (input.widget) {
|
||||
const targetDef = globalDefs[node.type];
|
||||
const targetWidget =
|
||||
targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
|
||||
|
||||
const widget = [targetWidget[0], config];
|
||||
const res = mergeIfValid(
|
||||
{
|
||||
widget,
|
||||
},
|
||||
targetWidget,
|
||||
false,
|
||||
null,
|
||||
widget
|
||||
);
|
||||
config = res?.customConfig ?? config;
|
||||
}
|
||||
}
|
||||
} else if (linksTo) {
|
||||
const [id, slot] = linksTo["0"];
|
||||
rerouteType = this.nodeData.nodes[id].outputs[slot].type;
|
||||
@ -282,10 +313,11 @@ export class GroupNodeConfig {
|
||||
}
|
||||
}
|
||||
|
||||
config.forceInput = true;
|
||||
return {
|
||||
input: {
|
||||
required: {
|
||||
[rerouteType]: [rerouteType, {}],
|
||||
[rerouteType]: [rerouteType, config],
|
||||
},
|
||||
},
|
||||
output: [rerouteType],
|
||||
@ -299,16 +331,17 @@ export class GroupNodeConfig {
|
||||
|
||||
getInputConfig(node, inputName, seenInputs, config, extra) {
|
||||
let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
|
||||
let key = name;
|
||||
let prefix = "";
|
||||
// Special handling for primitive to include the title if it is set rather than just "value"
|
||||
if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) {
|
||||
prefix = `${node.title ?? node.type} `;
|
||||
name = `${prefix}${inputName}`;
|
||||
key = name = `${prefix}${inputName}`;
|
||||
if (name in seenInputs) {
|
||||
name = `${prefix}${seenInputs[name]} ${inputName}`;
|
||||
}
|
||||
}
|
||||
seenInputs[name] = (seenInputs[name] ?? 1) + 1;
|
||||
seenInputs[key] = (seenInputs[key] ?? 1) + 1;
|
||||
|
||||
if (inputName === "seed" || inputName === "noise_seed") {
|
||||
if (!extra) extra = {};
|
||||
@ -420,10 +453,18 @@ export class GroupNodeConfig {
|
||||
defaultInput: true,
|
||||
});
|
||||
this.nodeDef.input.required[name] = config;
|
||||
this.newToOldWidgetMap[name] = { node, inputName };
|
||||
|
||||
if (!this.oldToNewWidgetMap[node.index]) {
|
||||
this.oldToNewWidgetMap[node.index] = {};
|
||||
}
|
||||
this.oldToNewWidgetMap[node.index][inputName] = name;
|
||||
|
||||
inputMap[slots.length + i] = this.inputCount++;
|
||||
}
|
||||
}
|
||||
|
||||
#convertedToProcess = [];
|
||||
processNodeInputs(node, seenInputs, inputs) {
|
||||
const inputMapping = [];
|
||||
|
||||
@ -434,7 +475,11 @@ export class GroupNodeConfig {
|
||||
const linksTo = this.linksTo[node.index] ?? {};
|
||||
const inputMap = (this.oldToNewInputMap[node.index] = {});
|
||||
this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
|
||||
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs);
|
||||
|
||||
// Converted inputs have to be processed after all other nodes as they'll be at the end of the list
|
||||
this.#convertedToProcess.push(() =>
|
||||
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)
|
||||
);
|
||||
|
||||
return inputMapping;
|
||||
}
|
||||
@ -597,11 +642,19 @@ export class GroupNodeHandler {
|
||||
const output = this.groupData.newToOldOutputMap[link.origin_slot];
|
||||
let innerNode = this.innerNodes[output.node.index];
|
||||
let l;
|
||||
while (innerNode.type === "Reroute") {
|
||||
while (innerNode?.type === "Reroute") {
|
||||
l = innerNode.getInputLink(0);
|
||||
innerNode = innerNode.getInputNode(0);
|
||||
}
|
||||
|
||||
if (!innerNode) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (l && GroupNodeHandler.isGroupNode(innerNode)) {
|
||||
return innerNode.updateLink(l);
|
||||
}
|
||||
|
||||
link.origin_id = innerNode.id;
|
||||
link.origin_slot = l?.origin_slot ?? output.slot;
|
||||
return link;
|
||||
@ -665,6 +718,8 @@ export class GroupNodeHandler {
|
||||
top = newNode.pos[1];
|
||||
}
|
||||
|
||||
if (!newNode.widgets) continue;
|
||||
|
||||
const map = this.groupData.oldToNewWidgetMap[innerNode.index];
|
||||
if (map) {
|
||||
const widgets = Object.keys(map);
|
||||
@ -721,7 +776,7 @@ export class GroupNodeHandler {
|
||||
}
|
||||
};
|
||||
|
||||
const reconnectOutputs = () => {
|
||||
const reconnectOutputs = (selectedIds) => {
|
||||
for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) {
|
||||
const output = node.outputs[groupOutputId];
|
||||
if (!output.links) continue;
|
||||
@ -861,7 +916,7 @@ export class GroupNodeHandler {
|
||||
if (innerNode.type === "PrimitiveNode") {
|
||||
innerNode.primitiveValue = newValue;
|
||||
const primitiveLinked = this.groupData.primitiveToWidget[old.node.index];
|
||||
for (const linked of primitiveLinked) {
|
||||
for (const linked of primitiveLinked ?? []) {
|
||||
const node = this.innerNodes[linked.nodeId];
|
||||
const widget = node.widgets.find((w) => w.name === linked.inputName);
|
||||
|
||||
@ -870,6 +925,18 @@ export class GroupNodeHandler {
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else if (innerNode.type === "Reroute") {
|
||||
const rerouteLinks = this.groupData.linksFrom[old.node.index];
|
||||
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
|
||||
const node = this.innerNodes[targetNodeId];
|
||||
const input = node.inputs[targetSlot];
|
||||
if (input.widget) {
|
||||
const widget = node.widgets?.find((w) => w.name === input.widget.name);
|
||||
if (widget) {
|
||||
widget.value = newValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const widget = innerNode.widgets?.find((w) => w.name === old.inputName);
|
||||
@ -897,33 +964,58 @@ export class GroupNodeHandler {
|
||||
this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
populateReroute(node, nodeId, map) {
|
||||
if (node.type !== "Reroute") return;
|
||||
|
||||
const link = this.groupData.linksFrom[nodeId]?.[0]?.[0];
|
||||
if (!link) return;
|
||||
const [, , targetNodeId, targetNodeSlot] = link;
|
||||
const targetNode = this.groupData.nodeData.nodes[targetNodeId];
|
||||
const inputs = targetNode.inputs;
|
||||
const targetWidget = inputs?.[targetNodeSlot].widget;
|
||||
if (!targetWidget) return;
|
||||
|
||||
const offset = inputs.length - (targetNode.widgets_values?.length ?? 0);
|
||||
const v = targetNode.widgets_values?.[targetNodeSlot - offset];
|
||||
if (v == null) return;
|
||||
|
||||
const widgetName = Object.values(map)[0];
|
||||
const widget = this.node.widgets.find(w => w.name === widgetName);
|
||||
if(widget) {
|
||||
widget.value = v;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
populateWidgets() {
|
||||
if (!this.node.widgets) return;
|
||||
|
||||
for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) {
|
||||
const node = this.groupData.nodeData.nodes[nodeId];
|
||||
|
||||
if (!node.widgets_values?.length) continue;
|
||||
|
||||
const map = this.groupData.oldToNewWidgetMap[nodeId];
|
||||
const map = this.groupData.oldToNewWidgetMap[nodeId] ?? {};
|
||||
const widgets = Object.keys(map);
|
||||
|
||||
if (!node.widgets_values?.length) {
|
||||
// special handling for populating values into reroutes
|
||||
// this allows primitives connect to them to pick up the correct value
|
||||
this.populateReroute(node, nodeId, map);
|
||||
continue;
|
||||
}
|
||||
|
||||
let linkedShift = 0;
|
||||
for (let i = 0; i < widgets.length; i++) {
|
||||
const oldName = widgets[i];
|
||||
const newName = map[oldName];
|
||||
const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName);
|
||||
const mainWidget = this.node.widgets[widgetIndex];
|
||||
if (!newName) {
|
||||
// New name will be null if its a converted widget
|
||||
this.populatePrimitive(node, nodeId, oldName, i, linkedShift);
|
||||
|
||||
if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift) || widgetIndex === -1) {
|
||||
// Find the inner widget and shift by the number of linked widgets as they will have been removed too
|
||||
const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName);
|
||||
linkedShift += innerWidget.linkedWidgets?.length ?? 0;
|
||||
continue;
|
||||
linkedShift += innerWidget?.linkedWidgets?.length ?? 0;
|
||||
}
|
||||
|
||||
if (widgetIndex === -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -33,6 +33,18 @@ function loadedImageToBlob(image) {
|
||||
return blob;
|
||||
}
|
||||
|
||||
function loadImage(imagePath) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const image = new Image();
|
||||
|
||||
image.onload = function() {
|
||||
resolve(image);
|
||||
};
|
||||
|
||||
image.src = imagePath;
|
||||
});
|
||||
}
|
||||
|
||||
async function uploadMask(filepath, formData) {
|
||||
await api.fetchApi('/upload/mask', {
|
||||
method: 'POST',
|
||||
@ -50,25 +62,25 @@ async function uploadMask(filepath, formData) {
|
||||
ClipspaceDialog.invalidatePreview();
|
||||
}
|
||||
|
||||
function prepareRGB(image, backupCanvas, backupCtx) {
|
||||
function prepare_mask(image, maskCanvas, maskCtx) {
|
||||
// paste mask data into alpha channel
|
||||
backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
|
||||
maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height);
|
||||
const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height);
|
||||
|
||||
// refine mask image
|
||||
for (let i = 0; i < backupData.data.length; i += 4) {
|
||||
if(backupData.data[i+3] == 255)
|
||||
backupData.data[i+3] = 0;
|
||||
// invert mask
|
||||
for (let i = 0; i < maskData.data.length; i += 4) {
|
||||
if(maskData.data[i+3] == 255)
|
||||
maskData.data[i+3] = 0;
|
||||
else
|
||||
backupData.data[i+3] = 255;
|
||||
maskData.data[i+3] = 255;
|
||||
|
||||
backupData.data[i] = 0;
|
||||
backupData.data[i+1] = 0;
|
||||
backupData.data[i+2] = 0;
|
||||
maskData.data[i] = 0;
|
||||
maskData.data[i+1] = 0;
|
||||
maskData.data[i+2] = 0;
|
||||
}
|
||||
|
||||
backupCtx.globalCompositeOperation = 'source-over';
|
||||
backupCtx.putImageData(backupData, 0, 0);
|
||||
maskCtx.globalCompositeOperation = 'source-over';
|
||||
maskCtx.putImageData(maskData, 0, 0);
|
||||
}
|
||||
|
||||
class MaskEditorDialog extends ComfyDialog {
|
||||
@ -155,10 +167,6 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
|
||||
// If it is specified as relative, using it only as a hidden placeholder for padding is recommended
|
||||
// to prevent anomalies where it exceeds a certain size and goes outside of the window.
|
||||
var placeholder = document.createElement("div");
|
||||
placeholder.style.position = "relative";
|
||||
placeholder.style.height = "50px";
|
||||
|
||||
var bottom_panel = document.createElement("div");
|
||||
bottom_panel.style.position = "absolute";
|
||||
bottom_panel.style.bottom = "0px";
|
||||
@ -180,18 +188,16 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
this.brush = brush;
|
||||
this.element.appendChild(imgCanvas);
|
||||
this.element.appendChild(maskCanvas);
|
||||
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
|
||||
this.element.appendChild(bottom_panel);
|
||||
document.body.appendChild(brush);
|
||||
|
||||
var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
|
||||
this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
|
||||
self.brush_size = event.target.value;
|
||||
self.updateBrushPreview(self, null, null);
|
||||
});
|
||||
var clearButton = this.createLeftButton("Clear",
|
||||
() => {
|
||||
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
|
||||
self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
|
||||
});
|
||||
var cancelButton = this.createRightButton("Cancel", () => {
|
||||
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||
@ -207,40 +213,42 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
|
||||
this.element.appendChild(imgCanvas);
|
||||
this.element.appendChild(maskCanvas);
|
||||
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
|
||||
this.element.appendChild(bottom_panel);
|
||||
|
||||
bottom_panel.appendChild(clearButton);
|
||||
bottom_panel.appendChild(this.saveButton);
|
||||
bottom_panel.appendChild(cancelButton);
|
||||
bottom_panel.appendChild(brush_size_slider);
|
||||
bottom_panel.appendChild(this.brush_size_slider);
|
||||
|
||||
imgCanvas.style.position = "absolute";
|
||||
maskCanvas.style.position = "absolute";
|
||||
|
||||
imgCanvas.style.position = "relative";
|
||||
imgCanvas.style.top = "200";
|
||||
imgCanvas.style.left = "0";
|
||||
|
||||
maskCanvas.style.position = "absolute";
|
||||
maskCanvas.style.top = imgCanvas.style.top;
|
||||
maskCanvas.style.left = imgCanvas.style.left;
|
||||
}
|
||||
|
||||
show() {
|
||||
async show() {
|
||||
this.zoom_ratio = 1.0;
|
||||
this.pan_x = 0;
|
||||
this.pan_y = 0;
|
||||
|
||||
if(!this.is_layout_created) {
|
||||
// layout
|
||||
const imgCanvas = document.createElement('canvas');
|
||||
const maskCanvas = document.createElement('canvas');
|
||||
const backupCanvas = document.createElement('canvas');
|
||||
|
||||
imgCanvas.id = "imageCanvas";
|
||||
maskCanvas.id = "maskCanvas";
|
||||
backupCanvas.id = "backupCanvas";
|
||||
|
||||
this.setlayout(imgCanvas, maskCanvas);
|
||||
|
||||
// prepare content
|
||||
this.imgCanvas = imgCanvas;
|
||||
this.maskCanvas = maskCanvas;
|
||||
this.backupCanvas = backupCanvas;
|
||||
this.maskCtx = maskCanvas.getContext('2d');
|
||||
this.backupCtx = backupCanvas.getContext('2d');
|
||||
this.maskCtx = maskCanvas.getContext('2d', {willReadFrequently: true });
|
||||
|
||||
this.setEventHandler(maskCanvas);
|
||||
|
||||
@ -252,6 +260,8 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
mutations.forEach(function(mutation) {
|
||||
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
|
||||
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
|
||||
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
|
||||
self.brush.style.display = "none";
|
||||
ComfyApp.onClipspaceEditorClosed();
|
||||
}
|
||||
|
||||
@ -264,7 +274,8 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
observer.observe(this.element, config);
|
||||
}
|
||||
|
||||
this.setImages(this.imgCanvas, this.backupCanvas);
|
||||
// The keydown event needs to be reconfigured when closing the dialog as it gets removed.
|
||||
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
|
||||
|
||||
if(ComfyApp.clipspace_return_node) {
|
||||
this.saveButton.innerText = "Save to node";
|
||||
@ -275,97 +286,157 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
this.saveButton.disabled = false;
|
||||
|
||||
this.element.style.display = "block";
|
||||
this.element.style.width = "85%";
|
||||
this.element.style.margin = "0 7.5%";
|
||||
this.element.style.height = "100vh";
|
||||
this.element.style.top = "50%";
|
||||
this.element.style.left = "42%";
|
||||
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
|
||||
|
||||
await this.setImages(this.imgCanvas);
|
||||
|
||||
this.is_visible = true;
|
||||
}
|
||||
|
||||
isOpened() {
|
||||
return this.element.style.display == "block";
|
||||
}
|
||||
|
||||
setImages(imgCanvas, backupCanvas) {
|
||||
const imgCtx = imgCanvas.getContext('2d');
|
||||
const backupCtx = backupCanvas.getContext('2d');
|
||||
invalidateCanvas(orig_image, mask_image) {
|
||||
this.imgCanvas.width = orig_image.width;
|
||||
this.imgCanvas.height = orig_image.height;
|
||||
|
||||
this.maskCanvas.width = orig_image.width;
|
||||
this.maskCanvas.height = orig_image.height;
|
||||
|
||||
let imgCtx = this.imgCanvas.getContext('2d', {willReadFrequently: true });
|
||||
let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true });
|
||||
|
||||
imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height);
|
||||
prepare_mask(mask_image, this.maskCanvas, maskCtx);
|
||||
}
|
||||
|
||||
async setImages(imgCanvas) {
|
||||
let self = this;
|
||||
|
||||
const imgCtx = imgCanvas.getContext('2d', {willReadFrequently: true });
|
||||
const maskCtx = this.maskCtx;
|
||||
const maskCanvas = this.maskCanvas;
|
||||
|
||||
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
|
||||
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
|
||||
|
||||
// image load
|
||||
const orig_image = new Image();
|
||||
window.addEventListener("resize", () => {
|
||||
// repositioning
|
||||
imgCanvas.width = window.innerWidth - 250;
|
||||
imgCanvas.height = window.innerHeight - 200;
|
||||
|
||||
// redraw image
|
||||
let drawWidth = orig_image.width;
|
||||
let drawHeight = orig_image.height;
|
||||
if (orig_image.width > imgCanvas.width) {
|
||||
drawWidth = imgCanvas.width;
|
||||
drawHeight = (drawWidth / orig_image.width) * orig_image.height;
|
||||
}
|
||||
|
||||
if (drawHeight > imgCanvas.height) {
|
||||
drawHeight = imgCanvas.height;
|
||||
drawWidth = (drawHeight / orig_image.height) * orig_image.width;
|
||||
}
|
||||
|
||||
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
|
||||
|
||||
// update mask
|
||||
maskCanvas.width = drawWidth;
|
||||
maskCanvas.height = drawHeight;
|
||||
maskCanvas.style.top = imgCanvas.offsetTop + "px";
|
||||
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
|
||||
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
|
||||
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
|
||||
});
|
||||
|
||||
const filepath = ComfyApp.clipspace.images;
|
||||
|
||||
const touched_image = new Image();
|
||||
|
||||
touched_image.onload = function() {
|
||||
backupCanvas.width = touched_image.width;
|
||||
backupCanvas.height = touched_image.height;
|
||||
|
||||
prepareRGB(touched_image, backupCanvas, backupCtx);
|
||||
};
|
||||
|
||||
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
|
||||
alpha_url.searchParams.delete('channel');
|
||||
alpha_url.searchParams.delete('preview');
|
||||
alpha_url.searchParams.set('channel', 'a');
|
||||
touched_image.src = alpha_url;
|
||||
let mask_image = await loadImage(alpha_url);
|
||||
|
||||
// original image load
|
||||
orig_image.onload = function() {
|
||||
window.dispatchEvent(new Event('resize'));
|
||||
};
|
||||
|
||||
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
|
||||
rgb_url.searchParams.delete('channel');
|
||||
rgb_url.searchParams.set('channel', 'rgb');
|
||||
orig_image.src = rgb_url;
|
||||
this.image = orig_image;
|
||||
this.image = new Image();
|
||||
this.image.onload = function() {
|
||||
maskCanvas.width = self.image.width;
|
||||
maskCanvas.height = self.image.height;
|
||||
|
||||
self.invalidateCanvas(self.image, mask_image);
|
||||
self.initializeCanvasPanZoom();
|
||||
};
|
||||
this.image.src = rgb_url;
|
||||
}
|
||||
|
||||
setEventHandler(maskCanvas) {
|
||||
maskCanvas.addEventListener("contextmenu", (event) => {
|
||||
event.preventDefault();
|
||||
});
|
||||
initializeCanvasPanZoom() {
|
||||
// set initialize
|
||||
let drawWidth = this.image.width;
|
||||
let drawHeight = this.image.height;
|
||||
|
||||
let width = this.element.clientWidth;
|
||||
let height = this.element.clientHeight;
|
||||
|
||||
if (this.image.width > width) {
|
||||
drawWidth = width;
|
||||
drawHeight = (drawWidth / this.image.width) * this.image.height;
|
||||
}
|
||||
|
||||
if (drawHeight > height) {
|
||||
drawHeight = height;
|
||||
drawWidth = (drawHeight / this.image.height) * this.image.width;
|
||||
}
|
||||
|
||||
this.zoom_ratio = drawWidth/this.image.width;
|
||||
|
||||
const canvasX = (width - drawWidth) / 2;
|
||||
const canvasY = (height - drawHeight) / 2;
|
||||
this.pan_x = canvasX;
|
||||
this.pan_y = canvasY;
|
||||
|
||||
this.invalidatePanZoom();
|
||||
}
|
||||
|
||||
|
||||
invalidatePanZoom() {
|
||||
let raw_width = this.image.width * this.zoom_ratio;
|
||||
let raw_height = this.image.height * this.zoom_ratio;
|
||||
|
||||
if(this.pan_x + raw_width < 10) {
|
||||
this.pan_x = 10 - raw_width;
|
||||
}
|
||||
|
||||
if(this.pan_y + raw_height < 10) {
|
||||
this.pan_y = 10 - raw_height;
|
||||
}
|
||||
|
||||
let width = `${raw_width}px`;
|
||||
let height = `${raw_height}px`;
|
||||
|
||||
let left = `${this.pan_x}px`;
|
||||
let top = `${this.pan_y}px`;
|
||||
|
||||
this.maskCanvas.style.width = width;
|
||||
this.maskCanvas.style.height = height;
|
||||
this.maskCanvas.style.left = left;
|
||||
this.maskCanvas.style.top = top;
|
||||
|
||||
this.imgCanvas.style.width = width;
|
||||
this.imgCanvas.style.height = height;
|
||||
this.imgCanvas.style.left = left;
|
||||
this.imgCanvas.style.top = top;
|
||||
}
|
||||
|
||||
|
||||
setEventHandler(maskCanvas) {
|
||||
const self = this;
|
||||
maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
|
||||
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
|
||||
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
|
||||
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
|
||||
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
|
||||
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
|
||||
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
|
||||
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
|
||||
|
||||
if(!this.handler_registered) {
|
||||
maskCanvas.addEventListener("contextmenu", (event) => {
|
||||
event.preventDefault();
|
||||
});
|
||||
|
||||
this.element.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
|
||||
this.element.addEventListener('pointermove', (event) => this.pointMoveEvent(self,event));
|
||||
this.element.addEventListener('touchmove', (event) => this.pointMoveEvent(self,event));
|
||||
|
||||
this.element.addEventListener('dragstart', (event) => {
|
||||
if(event.ctrlKey) {
|
||||
event.preventDefault();
|
||||
}
|
||||
});
|
||||
|
||||
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
|
||||
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
|
||||
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
|
||||
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
|
||||
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
|
||||
|
||||
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
|
||||
|
||||
this.handler_registered = true;
|
||||
}
|
||||
}
|
||||
|
||||
brush_size = 10;
|
||||
@ -378,8 +449,10 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
const self = MaskEditorDialog.instance;
|
||||
if (event.key === ']') {
|
||||
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||
self.brush_slider_input.value = self.brush_size;
|
||||
} else if (event.key === '[') {
|
||||
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||
self.brush_slider_input.value = self.brush_size;
|
||||
} else if(event.key === 'Enter') {
|
||||
self.save();
|
||||
}
|
||||
@ -389,6 +462,10 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
|
||||
static handlePointerUp(event) {
|
||||
event.preventDefault();
|
||||
|
||||
this.mousedown_x = null;
|
||||
this.mousedown_y = null;
|
||||
|
||||
MaskEditorDialog.instance.drawing_mode = false;
|
||||
}
|
||||
|
||||
@ -398,24 +475,70 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
var centerX = self.cursorX;
|
||||
var centerY = self.cursorY;
|
||||
|
||||
brush.style.width = self.brush_size * 2 + "px";
|
||||
brush.style.height = self.brush_size * 2 + "px";
|
||||
brush.style.left = (centerX - self.brush_size) + "px";
|
||||
brush.style.top = (centerY - self.brush_size) + "px";
|
||||
brush.style.width = self.brush_size * 2 * this.zoom_ratio + "px";
|
||||
brush.style.height = self.brush_size * 2 * this.zoom_ratio + "px";
|
||||
brush.style.left = (centerX - self.brush_size * this.zoom_ratio) + "px";
|
||||
brush.style.top = (centerY - self.brush_size * this.zoom_ratio) + "px";
|
||||
}
|
||||
|
||||
handleWheelEvent(self, event) {
|
||||
if(event.deltaY < 0)
|
||||
self.brush_size = Math.min(self.brush_size+2, 100);
|
||||
else
|
||||
self.brush_size = Math.max(self.brush_size-2, 1);
|
||||
event.preventDefault();
|
||||
|
||||
self.brush_slider_input.value = self.brush_size;
|
||||
if(event.ctrlKey) {
|
||||
// zoom canvas
|
||||
if(event.deltaY < 0) {
|
||||
this.zoom_ratio = Math.min(10.0, this.zoom_ratio+0.2);
|
||||
}
|
||||
else {
|
||||
this.zoom_ratio = Math.max(0.2, this.zoom_ratio-0.2);
|
||||
}
|
||||
|
||||
this.invalidatePanZoom();
|
||||
}
|
||||
else {
|
||||
// adjust brush size
|
||||
if(event.deltaY < 0)
|
||||
this.brush_size = Math.min(this.brush_size+2, 100);
|
||||
else
|
||||
this.brush_size = Math.max(this.brush_size-2, 1);
|
||||
|
||||
this.brush_slider_input.value = this.brush_size;
|
||||
|
||||
this.updateBrushPreview(this);
|
||||
}
|
||||
}
|
||||
|
||||
pointMoveEvent(self, event) {
|
||||
this.cursorX = event.pageX;
|
||||
this.cursorY = event.pageY;
|
||||
|
||||
self.updateBrushPreview(self);
|
||||
|
||||
if(event.ctrlKey) {
|
||||
event.preventDefault();
|
||||
self.pan_move(self, event);
|
||||
}
|
||||
}
|
||||
|
||||
pan_move(self, event) {
|
||||
if(event.buttons == 1) {
|
||||
if(this.mousedown_x) {
|
||||
let deltaX = this.mousedown_x - event.clientX;
|
||||
let deltaY = this.mousedown_y - event.clientY;
|
||||
|
||||
self.pan_x = this.mousedown_pan_x - deltaX;
|
||||
self.pan_y = this.mousedown_pan_y - deltaY;
|
||||
|
||||
self.invalidatePanZoom();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
draw_move(self, event) {
|
||||
if(event.ctrlKey) {
|
||||
return;
|
||||
}
|
||||
|
||||
event.preventDefault();
|
||||
|
||||
this.cursorX = event.pageX;
|
||||
@ -439,6 +562,9 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
y = event.targetTouches[0].clientY - maskRect.top;
|
||||
}
|
||||
|
||||
x /= self.zoom_ratio;
|
||||
y /= self.zoom_ratio;
|
||||
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
@ -489,8 +615,8 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
}
|
||||
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
|
||||
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
|
||||
const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
|
||||
const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
|
||||
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
@ -540,6 +666,17 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
}
|
||||
|
||||
handlePointerDown(self, event) {
|
||||
if(event.ctrlKey) {
|
||||
if (event.buttons == 1) {
|
||||
this.mousedown_x = event.clientX;
|
||||
this.mousedown_y = event.clientY;
|
||||
|
||||
this.mousedown_pan_x = this.pan_x;
|
||||
this.mousedown_pan_y = this.pan_y;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
var brush_size = this.brush_size;
|
||||
if(event instanceof PointerEvent && event.pointerType == 'pen') {
|
||||
brush_size *= event.pressure;
|
||||
@ -551,8 +688,8 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
|
||||
event.preventDefault();
|
||||
const maskRect = self.maskCanvas.getBoundingClientRect();
|
||||
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
|
||||
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
|
||||
const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
|
||||
const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
|
||||
|
||||
self.maskCtx.beginPath();
|
||||
if (event.button == 0) {
|
||||
@ -570,15 +707,18 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
}
|
||||
|
||||
async save() {
|
||||
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
|
||||
const backupCanvas = document.createElement('canvas');
|
||||
const backupCtx = backupCanvas.getContext('2d', {willReadFrequently:true});
|
||||
backupCanvas.width = this.image.width;
|
||||
backupCanvas.height = this.image.height;
|
||||
|
||||
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
||||
backupCtx.clearRect(0,0, backupCanvas.width, backupCanvas.height);
|
||||
backupCtx.drawImage(this.maskCanvas,
|
||||
0, 0, this.maskCanvas.width, this.maskCanvas.height,
|
||||
0, 0, this.backupCanvas.width, this.backupCanvas.height);
|
||||
0, 0, backupCanvas.width, backupCanvas.height);
|
||||
|
||||
// paste mask data into alpha channel
|
||||
const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
|
||||
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
|
||||
|
||||
// refine mask image
|
||||
for (let i = 0; i < backupData.data.length; i += 4) {
|
||||
@ -615,7 +755,7 @@ class MaskEditorDialog extends ComfyDialog {
|
||||
ComfyApp.clipspace.widgets[index].value = item;
|
||||
}
|
||||
|
||||
const dataURL = this.backupCanvas.toDataURL();
|
||||
const dataURL = backupCanvas.toDataURL();
|
||||
const blob = dataURLToBlob(dataURL);
|
||||
|
||||
let original_url = new URL(this.image.src);
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs.js";
|
||||
|
||||
// Node that allows you to redirect connections for cleaner graphs
|
||||
|
||||
app.registerExtension({
|
||||
name: "Comfy.RerouteNode",
|
||||
registerCustomNodes() {
|
||||
registerCustomNodes(app) {
|
||||
class RerouteNode {
|
||||
constructor() {
|
||||
if (!this.properties) {
|
||||
@ -16,6 +17,12 @@ app.registerExtension({
|
||||
this.addInput("", "*");
|
||||
this.addOutput(this.properties.showOutputText ? "*" : "", "*");
|
||||
|
||||
this.onAfterGraphConfigured = function () {
|
||||
requestAnimationFrame(() => {
|
||||
this.onConnectionsChange(LiteGraph.INPUT, null, true, null);
|
||||
});
|
||||
};
|
||||
|
||||
this.onConnectionsChange = function (type, index, connected, link_info) {
|
||||
this.applyOrientation();
|
||||
|
||||
@ -47,6 +54,7 @@ app.registerExtension({
|
||||
const linkId = currentNode.inputs[0].link;
|
||||
if (linkId !== null) {
|
||||
const link = app.graph.links[linkId];
|
||||
if (!link) return;
|
||||
const node = app.graph.getNodeById(link.origin_id);
|
||||
const type = node.constructor.type;
|
||||
if (type === "Reroute") {
|
||||
@ -54,8 +62,7 @@ app.registerExtension({
|
||||
// We've found a circle
|
||||
currentNode.disconnectInput(link.target_slot);
|
||||
currentNode = null;
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
// Move the previous node
|
||||
currentNode = node;
|
||||
}
|
||||
@ -94,8 +101,11 @@ app.registerExtension({
|
||||
updateNodes.push(node);
|
||||
} else {
|
||||
// We've found an output
|
||||
const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
|
||||
if (inputType && nodeOutType !== inputType) {
|
||||
const nodeOutType =
|
||||
node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type
|
||||
? node.inputs[link.target_slot].type
|
||||
: null;
|
||||
if (inputType && inputType !== "*" && nodeOutType !== inputType) {
|
||||
// The output doesnt match our input so disconnect it
|
||||
node.disconnectInput(link.target_slot);
|
||||
} else {
|
||||
@ -111,6 +121,9 @@ app.registerExtension({
|
||||
const displayType = inputType || outputType || "*";
|
||||
const color = LGraphCanvas.link_type_colors[displayType];
|
||||
|
||||
let widgetConfig;
|
||||
let targetWidget;
|
||||
let widgetType;
|
||||
// Update the types of each node
|
||||
for (const node of updateNodes) {
|
||||
// If we dont have an input type we are always wildcard but we'll show the output type
|
||||
@ -125,10 +138,38 @@ app.registerExtension({
|
||||
const link = app.graph.links[l];
|
||||
if (link) {
|
||||
link.color = color;
|
||||
|
||||
if (app.configuringGraph) continue;
|
||||
const targetNode = app.graph.getNodeById(link.target_id);
|
||||
const targetInput = targetNode.inputs?.[link.target_slot];
|
||||
if (targetInput?.widget) {
|
||||
const config = getWidgetConfig(targetInput);
|
||||
if (!widgetConfig) {
|
||||
widgetConfig = config[1] ?? {};
|
||||
widgetType = config[0];
|
||||
}
|
||||
if (!targetWidget) {
|
||||
targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name);
|
||||
}
|
||||
|
||||
const merged = mergeIfValid(targetInput, [config[0], widgetConfig]);
|
||||
if (merged.customConfig) {
|
||||
widgetConfig = merged.customConfig;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const node of updateNodes) {
|
||||
if (widgetConfig && outputType) {
|
||||
node.inputs[0].widget = { name: "value" };
|
||||
setWidgetConfig(node.inputs[0], [widgetType ?? displayType, widgetConfig], targetWidget);
|
||||
} else {
|
||||
setWidgetConfig(node.inputs[0], null);
|
||||
}
|
||||
}
|
||||
|
||||
if (inputNode) {
|
||||
const link = app.graph.links[inputNode.inputs[0].link];
|
||||
if (link) {
|
||||
@ -173,8 +214,8 @@ app.registerExtension({
|
||||
},
|
||||
{
|
||||
// naming is inverted with respect to LiteGraphNode.horizontal
|
||||
// LiteGraphNode.horizontal == true means that
|
||||
// each slot in the inputs and outputs are layed out horizontally,
|
||||
// LiteGraphNode.horizontal == true means that
|
||||
// each slot in the inputs and outputs are layed out horizontally,
|
||||
// which is the opposite of the visual orientation of the inputs and outputs as a node
|
||||
content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"),
|
||||
callback: () => {
|
||||
@ -187,7 +228,7 @@ app.registerExtension({
|
||||
applyOrientation() {
|
||||
this.horizontal = this.properties.horizontal;
|
||||
if (this.horizontal) {
|
||||
// we correct the input position, because LiteGraphNode.horizontal
|
||||
// we correct the input position, because LiteGraphNode.horizontal
|
||||
// doesn't account for title presence
|
||||
// which reroute nodes don't have
|
||||
this.inputs[0].pos = [this.size[0] / 2, 0];
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
|
||||
import { applyTextReplacements } from "../../scripts/utils.js";
|
||||
// Use widget values and dates in output filenames
|
||||
|
||||
app.registerExtension({
|
||||
@ -7,84 +7,19 @@ app.registerExtension({
|
||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||
if (nodeData.name === "SaveImage") {
|
||||
const onNodeCreated = nodeType.prototype.onNodeCreated;
|
||||
|
||||
// Simple date formatter
|
||||
const parts = {
|
||||
d: (d) => d.getDate(),
|
||||
M: (d) => d.getMonth() + 1,
|
||||
h: (d) => d.getHours(),
|
||||
m: (d) => d.getMinutes(),
|
||||
s: (d) => d.getSeconds(),
|
||||
};
|
||||
const format =
|
||||
Object.keys(parts)
|
||||
.map((k) => k + k + "?")
|
||||
.join("|") + "|yyy?y?";
|
||||
|
||||
function formatDate(text, date) {
|
||||
return text.replace(new RegExp(format, "g"), function (text) {
|
||||
if (text === "yy") return (date.getFullYear() + "").substring(2);
|
||||
if (text === "yyyy") return date.getFullYear();
|
||||
if (text[0] in parts) {
|
||||
const p = parts[text[0]](date);
|
||||
return (p + "").padStart(text.length, "0");
|
||||
}
|
||||
return text;
|
||||
});
|
||||
}
|
||||
|
||||
// When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R
|
||||
// When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R
|
||||
nodeType.prototype.onNodeCreated = function () {
|
||||
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
|
||||
|
||||
const widget = this.widgets.find((w) => w.name === "filename_prefix");
|
||||
widget.serializeValue = () => {
|
||||
return widget.value.replace(/%([^%]+)%/g, function (match, text) {
|
||||
const split = text.split(".");
|
||||
if (split.length !== 2) {
|
||||
// Special handling for dates
|
||||
if (split[0].startsWith("date:")) {
|
||||
return formatDate(split[0].substring(5), new Date());
|
||||
}
|
||||
|
||||
if (text !== "width" && text !== "height") {
|
||||
// Dont warn on standard replacements
|
||||
console.warn("Invalid replacement pattern", text);
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
// Find node with matching S&R property name
|
||||
let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]);
|
||||
// If we cant, see if there is a node with that title
|
||||
if (!nodes.length) {
|
||||
nodes = app.graph._nodes.filter((n) => n.title === split[0]);
|
||||
}
|
||||
if (!nodes.length) {
|
||||
console.warn("Unable to find node", split[0]);
|
||||
return match;
|
||||
}
|
||||
|
||||
if (nodes.length > 1) {
|
||||
console.warn("Multiple nodes matched", split[0], "using first match");
|
||||
}
|
||||
|
||||
const node = nodes[0];
|
||||
|
||||
const widget = node.widgets?.find((w) => w.name === split[1]);
|
||||
if (!widget) {
|
||||
console.warn("Unable to find widget", split[1], "on node", split[0], node);
|
||||
return match;
|
||||
}
|
||||
|
||||
return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
|
||||
});
|
||||
return applyTextReplacements(app, widget.value);
|
||||
};
|
||||
|
||||
return r;
|
||||
};
|
||||
} else {
|
||||
// When any other node is created add a property to alias the node
|
||||
// When any other node is created add a property to alias the node
|
||||
const onNodeCreated = nodeType.prototype.onNodeCreated;
|
||||
nodeType.prototype.onNodeCreated = function () {
|
||||
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
|
||||
|
||||
@ -71,24 +71,21 @@ function graphEqual(a, b, root = true) {
|
||||
}
|
||||
|
||||
const undoRedo = async (e) => {
|
||||
const updateState = async (source, target) => {
|
||||
const prevState = source.pop();
|
||||
if (prevState) {
|
||||
target.push(activeState);
|
||||
isOurLoad = true;
|
||||
await app.loadGraphData(prevState, false);
|
||||
activeState = prevState;
|
||||
}
|
||||
}
|
||||
if (e.ctrlKey || e.metaKey) {
|
||||
if (e.key === "y") {
|
||||
const prevState = redo.pop();
|
||||
if (prevState) {
|
||||
undo.push(activeState);
|
||||
isOurLoad = true;
|
||||
await app.loadGraphData(prevState);
|
||||
activeState = prevState;
|
||||
}
|
||||
updateState(redo, undo);
|
||||
return true;
|
||||
} else if (e.key === "z") {
|
||||
const prevState = undo.pop();
|
||||
if (prevState) {
|
||||
redo.push(activeState);
|
||||
isOurLoad = true;
|
||||
await app.loadGraphData(prevState);
|
||||
activeState = prevState;
|
||||
}
|
||||
updateState(undo, redo);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,16 @@
|
||||
import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { applyTextReplacements } from "../../scripts/utils.js";
|
||||
|
||||
const CONVERTED_TYPE = "converted-widget";
|
||||
const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
|
||||
const CONFIG = Symbol();
|
||||
const GET_CONFIG = Symbol();
|
||||
const TARGET = Symbol(); // Used for reroutes to specify the real target widget
|
||||
|
||||
export function getWidgetConfig(slot) {
|
||||
return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]();
|
||||
}
|
||||
|
||||
function getConfig(widgetName) {
|
||||
const { nodeData } = this.constructor;
|
||||
@ -100,7 +106,6 @@ function getWidgetType(config) {
|
||||
return { type };
|
||||
}
|
||||
|
||||
|
||||
function isValidCombo(combo, obj) {
|
||||
// New input isnt a combo
|
||||
if (!(obj instanceof Array)) {
|
||||
@ -121,6 +126,31 @@ function isValidCombo(combo, obj) {
|
||||
return true;
|
||||
}
|
||||
|
||||
export function setWidgetConfig(slot, config, target) {
|
||||
if (!slot.widget) return;
|
||||
if (config) {
|
||||
slot.widget[GET_CONFIG] = () => config;
|
||||
slot.widget[TARGET] = target;
|
||||
} else {
|
||||
delete slot.widget;
|
||||
}
|
||||
|
||||
if (slot.link) {
|
||||
const link = app.graph.links[slot.link];
|
||||
if (link) {
|
||||
const originNode = app.graph.getNodeById(link.origin_id);
|
||||
if (originNode.type === "PrimitiveNode") {
|
||||
if (config) {
|
||||
originNode.recreateWidget();
|
||||
} else if(!app.configuringGraph) {
|
||||
originNode.disconnectOutput(0);
|
||||
originNode.onLastDisconnect();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
|
||||
if (!config1) {
|
||||
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
|
||||
@ -150,7 +180,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi
|
||||
|
||||
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
|
||||
for (const k of keys.values()) {
|
||||
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
|
||||
if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline") {
|
||||
let v1 = config1[1][k];
|
||||
let v2 = config2[1]?.[k];
|
||||
|
||||
@ -405,11 +435,16 @@ app.registerExtension({
|
||||
};
|
||||
},
|
||||
registerCustomNodes() {
|
||||
const replacePropertyName = "Run widget replace on values";
|
||||
class PrimitiveNode {
|
||||
constructor() {
|
||||
this.addOutput("connect to widget input", "*");
|
||||
this.serialize_widgets = true;
|
||||
this.isVirtualNode = true;
|
||||
|
||||
if (!this.properties || !(replacePropertyName in this.properties)) {
|
||||
this.addProperty(replacePropertyName, false, "boolean");
|
||||
}
|
||||
}
|
||||
|
||||
applyToGraph(extraLinks = []) {
|
||||
@ -430,18 +465,29 @@ app.registerExtension({
|
||||
}
|
||||
|
||||
let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks];
|
||||
let v = this.widgets?.[0].value;
|
||||
if(v && this.properties[replacePropertyName]) {
|
||||
v = applyTextReplacements(app, v);
|
||||
}
|
||||
|
||||
// For each output link copy our value over the original widget value
|
||||
for (const linkInfo of links) {
|
||||
const node = this.graph.getNodeById(linkInfo.target_id);
|
||||
const input = node.inputs[linkInfo.target_slot];
|
||||
const widgetName = input.widget.name;
|
||||
if (widgetName) {
|
||||
const widget = node.widgets.find((w) => w.name === widgetName);
|
||||
if (widget) {
|
||||
widget.value = this.widgets[0].value;
|
||||
if (widget.callback) {
|
||||
widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {});
|
||||
}
|
||||
let widget;
|
||||
if (input.widget[TARGET]) {
|
||||
widget = input.widget[TARGET];
|
||||
} else {
|
||||
const widgetName = input.widget.name;
|
||||
if (widgetName) {
|
||||
widget = node.widgets.find((w) => w.name === widgetName);
|
||||
}
|
||||
}
|
||||
|
||||
if (widget) {
|
||||
widget.value = v;
|
||||
if (widget.callback) {
|
||||
widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {});
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -494,14 +540,13 @@ app.registerExtension({
|
||||
this.#mergeWidgetConfig();
|
||||
|
||||
if (!links?.length) {
|
||||
this.#onLastDisconnect();
|
||||
this.onLastDisconnect();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onConnectOutput(slot, type, input, target_node, target_slot) {
|
||||
// Fires before the link is made allowing us to reject it if it isn't valid
|
||||
|
||||
// No widget, we cant connect
|
||||
if (!input.widget) {
|
||||
if (!(input.type in ComfyWidgets)) return false;
|
||||
@ -519,6 +564,10 @@ app.registerExtension({
|
||||
|
||||
#onFirstConnection(recreating) {
|
||||
// First connection can fire before the graph is ready on initial load so random things can be missing
|
||||
if (!this.outputs[0].links) {
|
||||
this.onLastDisconnect();
|
||||
return;
|
||||
}
|
||||
const linkId = this.outputs[0].links[0];
|
||||
const link = this.graph.links[linkId];
|
||||
if (!link) return;
|
||||
@ -546,10 +595,10 @@ app.registerExtension({
|
||||
this.outputs[0].name = type;
|
||||
this.outputs[0].widget = widget;
|
||||
|
||||
this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating);
|
||||
this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating, widget[TARGET]);
|
||||
}
|
||||
|
||||
#createWidget(inputData, node, widgetName, recreating) {
|
||||
#createWidget(inputData, node, widgetName, recreating, targetWidget) {
|
||||
let type = inputData[0];
|
||||
|
||||
if (type instanceof Array) {
|
||||
@ -563,7 +612,9 @@ app.registerExtension({
|
||||
widget = this.addWidget(type, "value", null, () => {}, {});
|
||||
}
|
||||
|
||||
if (node?.widgets && widget) {
|
||||
if (targetWidget) {
|
||||
widget.value = targetWidget.value;
|
||||
} else if (node?.widgets && widget) {
|
||||
const theirWidget = node.widgets.find((w) => w.name === widgetName);
|
||||
if (theirWidget) {
|
||||
widget.value = theirWidget.value;
|
||||
@ -577,11 +628,19 @@ app.registerExtension({
|
||||
}
|
||||
addValueControlWidgets(this, widget, control_value, undefined, inputData);
|
||||
let filter = this.widgets_values?.[2];
|
||||
if(filter && this.widgets.length === 3) {
|
||||
if (filter && this.widgets.length === 3) {
|
||||
this.widgets[2].value = filter;
|
||||
}
|
||||
}
|
||||
|
||||
// Restore any saved control values
|
||||
const controlValues = this.controlValues;
|
||||
if(this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) {
|
||||
for(let i = 0; i < controlValues.length; i++) {
|
||||
this.widgets[i + 1].value = controlValues[i];
|
||||
}
|
||||
}
|
||||
|
||||
// When our value changes, update other widgets to reflect our changes
|
||||
// e.g. so LoadImage shows correct image
|
||||
const callback = widget.callback;
|
||||
@ -610,12 +669,14 @@ app.registerExtension({
|
||||
}
|
||||
}
|
||||
|
||||
#recreateWidget() {
|
||||
const values = this.widgets.map((w) => w.value);
|
||||
recreateWidget() {
|
||||
const values = this.widgets?.map((w) => w.value);
|
||||
this.#removeWidgets();
|
||||
this.#onFirstConnection(true);
|
||||
for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
|
||||
return this.widgets[0];
|
||||
if (values?.length) {
|
||||
for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
|
||||
}
|
||||
return this.widgets?.[0];
|
||||
}
|
||||
|
||||
#mergeWidgetConfig() {
|
||||
@ -631,7 +692,7 @@ app.registerExtension({
|
||||
if (links?.length < 2 && hasConfig) {
|
||||
// Copy the widget options from the source
|
||||
if (links.length) {
|
||||
this.#recreateWidget();
|
||||
this.recreateWidget();
|
||||
}
|
||||
|
||||
return;
|
||||
@ -657,7 +718,7 @@ app.registerExtension({
|
||||
// Only allow connections where the configs match
|
||||
const output = this.outputs[0];
|
||||
const config2 = input.widget[GET_CONFIG]();
|
||||
return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget);
|
||||
return !!mergeIfValid.call(this, output, config2, forceUpdate, this.recreateWidget);
|
||||
}
|
||||
|
||||
#removeWidgets() {
|
||||
@ -668,11 +729,20 @@ app.registerExtension({
|
||||
w.onRemove();
|
||||
}
|
||||
}
|
||||
|
||||
// Temporarily store the current values in case the node is being recreated
|
||||
// e.g. by group node conversion
|
||||
this.controlValues = [];
|
||||
this.lastType = this.widgets[0]?.type;
|
||||
for(let i = 1; i < this.widgets.length; i++) {
|
||||
this.controlValues.push(this.widgets[i].value);
|
||||
}
|
||||
setTimeout(() => { delete this.lastType; delete this.controlValues }, 15);
|
||||
this.widgets.length = 0;
|
||||
}
|
||||
}
|
||||
|
||||
#onLastDisconnect() {
|
||||
onLastDisconnect() {
|
||||
// We cant remove + re-add the output here as if you drag a link over the same link
|
||||
// it removes, then re-adds, causing it to break
|
||||
this.outputs[0].type = "*";
|
||||
|
||||
@ -48,7 +48,7 @@
|
||||
EVENT_LINK_COLOR: "#A86",
|
||||
CONNECTING_LINK_COLOR: "#AFA",
|
||||
|
||||
MAX_NUMBER_OF_NODES: 1000, //avoid infinite loops
|
||||
MAX_NUMBER_OF_NODES: 10000, //avoid infinite loops
|
||||
DEFAULT_POSITION: [100, 100], //default node position
|
||||
VALID_SHAPES: ["default", "box", "round", "card"], //,"circle"
|
||||
|
||||
@ -3788,16 +3788,42 @@
|
||||
|
||||
/**
|
||||
* returns the bounding of the object, used for rendering purposes
|
||||
* bounding is: [topleft_cornerx, topleft_cornery, width, height]
|
||||
* @method getBounding
|
||||
* @return {Float32Array[4]} the total size
|
||||
* @param out {Float32Array[4]?} [optional] a place to store the output, to free garbage
|
||||
* @param compute_outer {boolean?} [optional] set to true to include the shadow and connection points in the bounding calculation
|
||||
* @return {Float32Array[4]} the bounding box in format of [topleft_cornerx, topleft_cornery, width, height]
|
||||
*/
|
||||
LGraphNode.prototype.getBounding = function(out) {
|
||||
LGraphNode.prototype.getBounding = function(out, compute_outer) {
|
||||
out = out || new Float32Array(4);
|
||||
out[0] = this.pos[0] - 4;
|
||||
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
|
||||
out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
|
||||
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
|
||||
const nodePos = this.pos;
|
||||
const isCollapsed = this.flags.collapsed;
|
||||
const nodeSize = this.size;
|
||||
|
||||
let left_offset = 0;
|
||||
// 1 offset due to how nodes are rendered
|
||||
let right_offset = 1 ;
|
||||
let top_offset = 0;
|
||||
let bottom_offset = 0;
|
||||
|
||||
if (compute_outer) {
|
||||
// 4 offset for collapsed node connection points
|
||||
left_offset = 4;
|
||||
// 6 offset for right shadow and collapsed node connection points
|
||||
right_offset = 6 + left_offset;
|
||||
// 4 offset for collapsed nodes top connection points
|
||||
top_offset = 4;
|
||||
// 5 offset for bottom shadow and collapsed node connection points
|
||||
bottom_offset = 5 + top_offset;
|
||||
}
|
||||
|
||||
out[0] = nodePos[0] - left_offset;
|
||||
out[1] = nodePos[1] - LiteGraph.NODE_TITLE_HEIGHT - top_offset;
|
||||
out[2] = isCollapsed ?
|
||||
(this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) + right_offset :
|
||||
nodeSize[0] + right_offset;
|
||||
out[3] = isCollapsed ?
|
||||
LiteGraph.NODE_TITLE_HEIGHT + bottom_offset :
|
||||
nodeSize[1] + LiteGraph.NODE_TITLE_HEIGHT + bottom_offset;
|
||||
|
||||
if (this.onBounding) {
|
||||
this.onBounding(out);
|
||||
@ -7674,7 +7700,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!overlapBounding(this.visible_area, n.getBounding(temp))) {
|
||||
if (!overlapBounding(this.visible_area, n.getBounding(temp, true))) {
|
||||
continue;
|
||||
} //out of the visible area
|
||||
|
||||
@ -11336,6 +11362,7 @@ LGraphNode.prototype.executeAction = function(action)
|
||||
name_element.innerText = title;
|
||||
var value_element = dialog.querySelector(".value");
|
||||
value_element.value = value;
|
||||
value_element.select();
|
||||
|
||||
var input = value_element;
|
||||
input.addEventListener("keydown", function(e) {
|
||||
|
||||
@ -1559,9 +1559,12 @@ export class ComfyApp {
|
||||
/**
|
||||
* Populates the graph with the specified workflow data
|
||||
* @param {*} graphData A serialized graph object
|
||||
* @param { boolean } clean If the graph state, e.g. images, should be cleared
|
||||
*/
|
||||
async loadGraphData(graphData) {
|
||||
this.clean();
|
||||
async loadGraphData(graphData, clean = true) {
|
||||
if (clean !== false) {
|
||||
this.clean();
|
||||
}
|
||||
|
||||
let reset_invalid_values = false;
|
||||
if (!graphData) {
|
||||
@ -1771,15 +1774,26 @@ export class ComfyApp {
|
||||
if (parent?.updateLink) {
|
||||
link = parent.updateLink(link);
|
||||
}
|
||||
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
|
||||
if (link) {
|
||||
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[String(node.id)] = {
|
||||
let node_data = {
|
||||
inputs,
|
||||
class_type: node.comfyClass,
|
||||
};
|
||||
|
||||
if (this.ui.settings.getSettingValue("Comfy.DevMode")) {
|
||||
// Ignored by the backend.
|
||||
node_data["_meta"] = {
|
||||
title: node.title,
|
||||
}
|
||||
}
|
||||
|
||||
output[String(node.id)] = node_data;
|
||||
}
|
||||
}
|
||||
|
||||
@ -2006,12 +2020,8 @@ export class ComfyApp {
|
||||
async refreshComboInNodes() {
|
||||
const defs = await api.getNodeDefs();
|
||||
|
||||
for(const nodeId in LiteGraph.registered_node_types) {
|
||||
const node = LiteGraph.registered_node_types[nodeId];
|
||||
const nodeDef = defs[nodeId];
|
||||
if(!nodeDef) continue;
|
||||
|
||||
node.nodeData = nodeDef;
|
||||
for (const nodeId in defs) {
|
||||
this.registerNodeDef(nodeId, defs[nodeId]);
|
||||
}
|
||||
|
||||
for(let nodeNum in this.graph._nodes) {
|
||||
|
||||
@ -177,6 +177,7 @@ LGraphCanvas.prototype.computeVisibleNodes = function () {
|
||||
for (const w of node.widgets) {
|
||||
if (w.element) {
|
||||
w.element.hidden = hidden;
|
||||
w.element.style.display = hidden ? "none" : undefined;
|
||||
if (hidden) {
|
||||
w.options.onHide?.(w);
|
||||
}
|
||||
|
||||
67
web/scripts/utils.js
Normal file
67
web/scripts/utils.js
Normal file
@ -0,0 +1,67 @@
|
||||
// Simple date formatter
|
||||
const parts = {
|
||||
d: (d) => d.getDate(),
|
||||
M: (d) => d.getMonth() + 1,
|
||||
h: (d) => d.getHours(),
|
||||
m: (d) => d.getMinutes(),
|
||||
s: (d) => d.getSeconds(),
|
||||
};
|
||||
const format =
|
||||
Object.keys(parts)
|
||||
.map((k) => k + k + "?")
|
||||
.join("|") + "|yyy?y?";
|
||||
|
||||
function formatDate(text, date) {
|
||||
return text.replace(new RegExp(format, "g"), function (text) {
|
||||
if (text === "yy") return (date.getFullYear() + "").substring(2);
|
||||
if (text === "yyyy") return date.getFullYear();
|
||||
if (text[0] in parts) {
|
||||
const p = parts[text[0]](date);
|
||||
return (p + "").padStart(text.length, "0");
|
||||
}
|
||||
return text;
|
||||
});
|
||||
}
|
||||
|
||||
export function applyTextReplacements(app, value) {
|
||||
return value.replace(/%([^%]+)%/g, function (match, text) {
|
||||
const split = text.split(".");
|
||||
if (split.length !== 2) {
|
||||
// Special handling for dates
|
||||
if (split[0].startsWith("date:")) {
|
||||
return formatDate(split[0].substring(5), new Date());
|
||||
}
|
||||
|
||||
if (text !== "width" && text !== "height") {
|
||||
// Dont warn on standard replacements
|
||||
console.warn("Invalid replacement pattern", text);
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
// Find node with matching S&R property name
|
||||
let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]);
|
||||
// If we cant, see if there is a node with that title
|
||||
if (!nodes.length) {
|
||||
nodes = app.graph._nodes.filter((n) => n.title === split[0]);
|
||||
}
|
||||
if (!nodes.length) {
|
||||
console.warn("Unable to find node", split[0]);
|
||||
return match;
|
||||
}
|
||||
|
||||
if (nodes.length > 1) {
|
||||
console.warn("Multiple nodes matched", split[0], "using first match");
|
||||
}
|
||||
|
||||
const node = nodes[0];
|
||||
|
||||
const widget = node.widgets?.find((w) => w.name === split[1]);
|
||||
if (!widget) {
|
||||
console.warn("Unable to find widget", split[1], "on node", split[0], node);
|
||||
return match;
|
||||
}
|
||||
|
||||
return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
|
||||
});
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user