Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)

Commit 369aeb598f: Merge upstream, fix 3.12 compatibility, fix nightlies issue, fix broken node
@@ -1,3 +0,0 @@
-..\python_embeded\python.exe .\update.py ..\ComfyUI\
-..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
-pause
@@ -1,2 +0,0 @@
-.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention
-pause
.github/workflows/test-ui.yaml (vendored, 2 lines changed)

@@ -22,5 +22,5 @@ jobs:
       run: |
         npm ci
         npm run test:generate
-        npm test
+        npm test -- --verbose
       working-directory: ./tests-ui
@@ -2,6 +2,24 @@ name: "Windows Release Nightly pytorch"

 on:
   workflow_dispatch:
+    inputs:
+      cu:
+        description: 'cuda version'
+        required: true
+        type: string
+        default: "121"
+
+      python_minor:
+        description: 'python minor version'
+        required: true
+        type: string
+        default: "12"
+
+      python_patch:
+        description: 'python patch version'
+        required: true
+        type: string
+        default: "1"
 #  push:
 #    branches:
 #      - master
@@ -20,21 +38,21 @@ jobs:
         persist-credentials: false
     - uses: actions/setup-python@v4
       with:
-        python-version: '3.11.6'
+        python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
    - shell: bash
      run: |
        cd ..
        cp -r ComfyUI ComfyUI_copy
-        curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip
+        curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
        unzip python_embeded.zip -d python_embeded
        cd python_embeded
-        echo 'import site' >> ./python311._pth
+        echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
        curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
        ./python.exe get-pip.py
-        python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+        python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
        ls ../temp_wheel_dir
        ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
-        sed -i '1i../ComfyUI' ./python311._pth
+        sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
        cd ..

        git clone https://github.com/comfyanonymous/taesd

@@ -49,9 +67,10 @@ jobs:
        mkdir update
        cp -r ComfyUI/.ci/update_windows/* ./update/
        cp -r ComfyUI/.ci/windows_base_files/* ./
-        cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
-        cp -r ComfyUI/.ci/nightly/windows_base_files/* ./

+        echo "..\python_embeded\python.exe .\update.py ..\ComfyUI\\
+        ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
+        pause" > ./update/update_comfyui_and_python_dependencies.bat
        cd ..

        "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch
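With the nightly workflow parametrized, a build can be dispatched with explicit versions. A hedged sketch using the GitHub REST API (OWNER/REPO, the token, and the workflow file name are placeholder assumptions, not taken from this diff):

import requests

# Hypothetical dispatch of the parametrized nightly build; inputs mirror the
# workflow_dispatch inputs added above.
resp = requests.post(
    "https://api.github.com/repos/OWNER/REPO/actions/workflows/windows-release-nightly-pytorch.yml/dispatches",
    headers={"Authorization": "Bearer YOUR_TOKEN", "Accept": "application/vnd.github+json"},
    json={"ref": "master", "inputs": {"cu": "121", "python_minor": "12", "python_patch": "1"}},
)
resp.raise_for_status()  # 204 No Content on success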
.vscode/settings.json (vendored, file deleted)

@@ -1,9 +0,0 @@
-{
-    "path-intellisense.mappings": {
-        "../": "${workspaceFolder}/web/extensions/core"
-    },
-    "[python]": {
-        "editor.defaultFormatter": "ms-python.autopep8"
-    },
-    "python.formatting.provider": "none"
-}
__init__.py (new empty file)
@@ -53,7 +53,7 @@ class ControlNet(nn.Module):
                  transformer_depth_middle=None,
                  transformer_depth_output=None,
                  device=None,
-                 operations=ops,
+                 operations=ops.disable_weight_init,
                  **kwargs,
                  ):
         super().__init__()

@@ -141,24 +141,24 @@ class ControlNet(nn.Module):
                 )
             ]
         )
-        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])

         self.input_hint_block = TimestepEmbedSequential(
-            operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
+            operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 16, 16, 3, padding=1),
+            operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 32, 32, 3, padding=1),
+            operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 96, 96, 3, padding=1),
+            operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
             nn.SiLU(),
-            zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
+            operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
         )

         self._feature_size = model_channels

@@ -206,7 +206,7 @@ class ControlNet(nn.Module):
                     )
                 )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                 self._feature_size += ch
                 input_block_chans.append(ch)
             if level != len(channel_mult) - 1:

@@ -234,7 +234,7 @@ class ControlNet(nn.Module):
                 )
                 ch = out_ch
                 input_block_chans.append(ch)
-                self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                 ds *= 2
                 self._feature_size += ch

@@ -276,14 +276,14 @@ class ControlNet(nn.Module):
                         operations=operations
         )]
         self.middle_block = TimestepEmbedSequential(*mid_block)
-        self.middle_block_out = self.make_zero_conv(ch, operations=operations)
+        self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
         self._feature_size += ch

-    def make_zero_conv(self, channels, operations=None):
-        return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
+    def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
+        return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))

     def forward(self, x, hint, timesteps, context, y=None, **kwargs):
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
         emb = self.time_embed(t_emb)

         guided_hint = self.input_hint_block(hint, emb, context)

@@ -295,7 +295,7 @@ class ControlNet(nn.Module):
             assert y.shape[0] == x.shape[0]
             emb = emb + self.label_emb(y)

-        h = x.type(self.dtype)
+        h = x
         for module, zero_conv in zip(self.input_blocks, self.zero_convs):
             if guided_hint is not None:
                 h = module(h, emb, context)
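The conv_nd calls above now take dtype/device at construction, and the zero_module wrapper is dropped since the weights are overwritten from the checkpoint anyway. A minimal sketch of what an operations namespace like ops.disable_weight_init presumably looks like (an assumption inferred from its usage here, not a verbatim copy of comfy/ops.py):

import torch

class disable_weight_init:
    # Layers built through this namespace skip random initialization, so a model
    # can be allocated directly in its target dtype/device and then filled from
    # a state dict without wasted init work.
    class Conv2d(torch.nn.Conv2d):
        def reset_parameters(self):
            return None  # leave the torch.empty() storage as-is

    class Linear(torch.nn.Linear):
        def reset_parameters(self):
            return None

    @classmethod
    def conv_nd(cls, dims, *args, **kwargs):
        # mirrors the conv_nd helper removed from ControlLoraOps further down
        if dims == 2:
            return cls.Conv2d(*args, **kwargs)
        raise ValueError(f"unsupported dimensions: {dims}")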
@@ -55,13 +55,19 @@ fp_group = parser.add_mutually_exclusive_group()
 fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
 fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")

-parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+fpunet_group = parser.add_mutually_exclusive_group()
+fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
+fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
+fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")

 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
 fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
 fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")

+parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
+
 fpte_group = parser.add_mutually_exclusive_group()
 fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
 fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")

@@ -98,7 +104,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e

 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
+parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
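The four unet precision flags now live in one mutually exclusive group, so argparse itself rejects contradictory combinations instead of silently picking one. A quick standalone illustration (not ComfyUI code):

import argparse

parser = argparse.ArgumentParser()
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true")
fpunet_group.add_argument("--fp16-unet", action="store_true")

print(parser.parse_args(["--bf16-unet"]))  # Namespace(bf16_unet=True, fp16_unet=False)
# parser.parse_args(["--bf16-unet", "--fp16-unet"])
# -> error: argument --fp16-unet: not allowed with argument --bf16-unet (exits)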
comfy/clip_model.py (new file, 188 lines)

@@ -0,0 +1,188 @@
+import torch
+from comfy.ldm.modules.attention import optimized_attention_for_device
+
+class CLIPAttention(torch.nn.Module):
+    def __init__(self, embed_dim, heads, dtype, device, operations):
+        super().__init__()
+
+        self.heads = heads
+        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x, mask=None, optimized_attention=None):
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+
+        out = optimized_attention(q, k, v, self.heads, mask)
+        return self.out_proj(out)
+
+ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
+               "gelu": torch.nn.functional.gelu,
+}
+
+class CLIPMLP(torch.nn.Module):
+    def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
+        super().__init__()
+        self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
+        self.activation = ACTIVATIONS[activation]
+        self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.activation(x)
+        x = self.fc2(x)
+        return x
+
+class CLIPLayer(torch.nn.Module):
+    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+        super().__init__()
+        self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
+        self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
+
+    def forward(self, x, mask=None, optimized_attention=None):
+        x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
+        x += self.mlp(self.layer_norm2(x))
+        return x
+
+
+class CLIPEncoder(torch.nn.Module):
+    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+        super().__init__()
+        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
+
+    def forward(self, x, mask=None, intermediate_output=None):
+        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
+
+        if intermediate_output is not None:
+            if intermediate_output < 0:
+                intermediate_output = len(self.layers) + intermediate_output
+
+        intermediate = None
+        for i, l in enumerate(self.layers):
+            x = l(x, mask, optimized_attention)
+            if i == intermediate_output:
+                intermediate = x.clone()
+        return x, intermediate
+
+class CLIPEmbeddings(torch.nn.Module):
+    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
+        super().__init__()
+        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+    def forward(self, input_tokens):
+        return self.token_embedding(input_tokens) + self.position_embedding.weight
+
+
+class CLIPTextModel_(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        num_layers = config_dict["num_hidden_layers"]
+        embed_dim = config_dict["hidden_size"]
+        heads = config_dict["num_attention_heads"]
+        intermediate_size = config_dict["intermediate_size"]
+        intermediate_activation = config_dict["hidden_act"]
+
+        super().__init__()
+        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+        self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+
+    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
+        x = self.embeddings(input_tokens)
+        mask = None
+        if attention_mask is not None:
+            mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
+            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+
+        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+        if mask is not None:
+            mask += causal_mask
+        else:
+            mask = causal_mask
+
+        x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
+        x = self.final_layer_norm(x)
+        if i is not None and final_layer_norm_intermediate:
+            i = self.final_layer_norm(i)
+
+        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
+        return x, i, pooled_output
+
+class CLIPTextModel(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        self.num_layers = config_dict["num_hidden_layers"]
+        self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
+        self.dtype = dtype
+
+    def get_input_embeddings(self):
+        return self.text_model.embeddings.token_embedding
+
+    def set_input_embeddings(self, embeddings):
+        self.text_model.embeddings.token_embedding = embeddings
+
+    def forward(self, *args, **kwargs):
+        return self.text_model(*args, **kwargs)
+
+class CLIPVisionEmbeddings(torch.nn.Module):
+    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+
+        self.patch_embedding = operations.Conv2d(
+            in_channels=num_channels,
+            out_channels=embed_dim,
+            kernel_size=patch_size,
+            stride=patch_size,
+            bias=False,
+            dtype=dtype,
+            device=device
+        )
+
+        num_patches = (image_size // patch_size) ** 2
+        num_positions = num_patches + 1
+        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+    def forward(self, pixel_values):
+        embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
+        return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
+
+
+class CLIPVision(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        num_layers = config_dict["num_hidden_layers"]
+        embed_dim = config_dict["hidden_size"]
+        heads = config_dict["num_attention_heads"]
+        intermediate_size = config_dict["intermediate_size"]
+        intermediate_activation = config_dict["hidden_act"]
+
+        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
+        self.pre_layrnorm = operations.LayerNorm(embed_dim)
+        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+        self.post_layernorm = operations.LayerNorm(embed_dim)
+
+    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
+        x = self.embeddings(pixel_values)
+        x = self.pre_layrnorm(x)
+        #TODO: attention_mask?
+        x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
+        pooled_output = self.post_layernorm(x[:, 0, :])
+        return x, i, pooled_output
+
+class CLIPVisionModelProjection(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        self.vision_model = CLIPVision(config_dict, dtype, device, operations)
+        self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+
+    def forward(self, *args, **kwargs):
+        x = self.vision_model(*args, **kwargs)
+        out = self.visual_projection(x[2])
+        return (x[0], x[1], out)
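For orientation, a hedged usage sketch of the new text tower. The config keys follow the constructor above; the ops stand-in uses plain torch layers (which accept the same dtype/device factory kwargs), and it assumes ComfyUI's package layout is importable so optimized_attention_for_device resolves:

import torch
from comfy import clip_model

class ops:  # stand-in for comfy.ops; plain torch layers take the same kwargs
    Linear = torch.nn.Linear
    LayerNorm = torch.nn.LayerNorm

config = {"num_hidden_layers": 2, "hidden_size": 64, "num_attention_heads": 4,
          "intermediate_size": 128, "hidden_act": "gelu"}
model = clip_model.CLIPTextModel(config, dtype=torch.float32, device="cpu", operations=ops)

tokens = torch.randint(0, 49408, (1, 77))
with torch.no_grad():
    last_hidden, penultimate, pooled = model(tokens, intermediate_output=-2)
print(last_hidden.shape, penultimate.shape, pooled.shape)
# torch.Size([1, 77, 64]) torch.Size([1, 77, 64]) torch.Size([1, 64])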
@@ -1,36 +1,43 @@
-from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
 from .utils import load_torch_file, transformers_convert
 import os
 import torch
-import contextlib
+import json

 from . import ops
 from . import model_patcher
 from . import model_management
+from . import clip_model
+
+
+class Output:
+    def __getitem__(self, key):
+        return getattr(self, key)
+    def __setitem__(self, key, item):
+        setattr(self, key, item)

 def clip_preprocess(image, size=224):
     mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
     std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
-    scale = (size / min(image.shape[1], image.shape[2]))
-    image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True)
-    h = (image.shape[2] - size)//2
-    w = (image.shape[3] - size)//2
-    image = image[:,:,h:h+size,w:w+size]
+    image = image.movedim(-1, 1)
+    if not (image.shape[2] == size and image.shape[3] == size):
+        scale = (size / min(image.shape[2], image.shape[3]))
+        image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
+        h = (image.shape[2] - size)//2
+        w = (image.shape[3] - size)//2
+        image = image[:,:,h:h+size,w:w+size]
     image = torch.clip((255. * image), 0, 255).round() / 255.0
     return (image - mean.view([3,1,1])) / std.view([3,1,1])

 class ClipVisionModel():
     def __init__(self, json_config):
-        config = CLIPVisionConfig.from_json_file(json_config)
+        with open(json_config) as f:
+            config = json.load(f)
+
         self.load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        self.dtype = torch.float32
-        if model_management.should_use_fp16(self.load_device, prioritize_performance=False):
-            self.dtype = torch.float16
-
-        with ops.use_comfy_ops(offload_device, self.dtype):
-            with modeling_utils.no_init_weights():
-                self.model = CLIPVisionModelWithProjection(config)
-        self.model.to(self.dtype)
-
+        self.dtype = model_management.text_encoder_dtype(self.load_device)
+        self.model = clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ops.manual_cast)
+        self.model.eval()

         self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
     def load_sd(self, sd):

@@ -38,25 +45,13 @@ class ClipVisionModel():

     def encode_image(self, image):
         model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device))
-
-        if self.dtype != torch.float32:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = lambda a, b: contextlib.nullcontext(a)
-
-        with precision_scope(model_management.get_autocast_device(self.load_device), torch.float32):
-            outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
-
-        for k in outputs:
-            t = outputs[k]
-            if t is not None:
-                if k == 'hidden_states':
-                    outputs["penultimate_hidden_states"] = t[-2].cpu()
-                    outputs["hidden_states"] = None
-                else:
-                    outputs[k] = t.cpu()
-
+        pixel_values = clip_preprocess(image.to(self.load_device)).float()
+        out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+
+        outputs = Output()
+        outputs["last_hidden_state"] = out[0].to(model_management.intermediate_device())
+        outputs["image_embeds"] = out[2].to(model_management.intermediate_device())
+        outputs["penultimate_hidden_states"] = out[1].to(model_management.intermediate_device())
         return outputs

 def convert_to_transformers(sd, prefix):

@@ -86,6 +81,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     if convert_keys:
         sd = convert_to_transformers(sd, prefix)
     if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
+        # todo: fix the importlib issue here
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
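Quick sanity check of the new early-out in clip_preprocess (assumes ComfyUI is importable and a ComfyUI-style BHWC float image in [0, 1]): an input that is already size x size now skips the bicubic resize and crop entirely.

import torch
from comfy.clip_vision import clip_preprocess

img = torch.rand(1, 224, 224, 3)   # batch, height, width, channels
out = clip_preprocess(img)         # no interpolate/crop for a 224x224 input
print(out.shape)                   # torch.Size([1, 3, 224, 224]), CLIP-normalized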
@@ -13,6 +13,8 @@ import typing
 from dataclasses import dataclass
 from typing import Tuple
 import sys
+import gc
+import inspect

 import torch

@@ -442,6 +444,8 @@ class PromptExecutor:
         for x in executed:
             self.old_prompt[x] = copy.deepcopy(prompt[x])
         self.server.last_node_id = None
+        if model_management.DISABLE_SMART_MEMORY:
+            model_management.unload_all_models()

@@ -462,6 +466,14 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
     errors = []
     valid = True
+
+    # todo: investigate if these are at the right indent level
+    info = None
+    val = None
+
+    validate_function_inputs = []
+    if hasattr(obj_class, "VALIDATE_INPUTS"):
+        validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args

     for x in required_inputs:
         if x not in inputs:
             error = {

@@ -591,29 +603,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
             errors.append(error)
             continue

-        if hasattr(obj_class, "VALIDATE_INPUTS"):
-            input_data_all = get_input_data(inputs, obj_class, unique_id)
-            # ret = obj_class.VALIDATE_INPUTS(**input_data_all)
-            ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
-            for i, r3 in enumerate(ret):
-                if r3 is not True:
-                    details = f"{x}"
-                    if r3 is not False:
-                        details += f" - {str(r3)}"
-
-                    error = {
-                        "type": "custom_validation_failed",
-                        "message": "Custom validation failed for node",
-                        "details": details,
-                        "extra_info": {
-                            "input_name": x,
-                            "input_config": info,
-                            "received_value": val,
-                        }
-                    }
-                    errors.append(error)
-                    continue
-        else:
+        if x not in validate_function_inputs:
             if isinstance(type_input, list):
                 if val not in type_input:
                     input_config = info

@@ -640,6 +630,35 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
             errors.append(error)
             continue
+
+    if len(validate_function_inputs) > 0:
+        input_data_all = get_input_data(inputs, obj_class, unique_id)
+        input_filtered = {}
+        for x in input_data_all:
+            if x in validate_function_inputs:
+                input_filtered[x] = input_data_all[x]
+
+        #ret = obj_class.VALIDATE_INPUTS(**input_filtered)
+        ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
+        for x in input_filtered:
+            for i, r in enumerate(ret):
+                if r is not True:
+                    details = f"{x}"
+                    if r is not False:
+                        details += f" - {str(r)}"
+
+                    error = {
+                        "type": "custom_validation_failed",
+                        "message": "Custom validation failed for node",
+                        "details": details,
+                        "extra_info": {
+                            "input_name": x,
+                            "input_config": info,
+                            "received_value": val,
+                        }
+                    }
+                    errors.append(error)
+                    continue
+
     if len(errors) > 0 or valid is not True:
         ret = (False, errors, unique_id)
     else:

@@ -771,7 +790,7 @@ class PromptQueue:
             self.server.queue_updated()
             self.not_empty.notify()

-    def get(self, timeout=None) -> typing.Tuple[QueueTuple, int]:
+    def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
         with self.not_empty:
             while len(self.queue) == 0:
                 self.not_empty.wait(timeout=timeout)
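The refactor above runs VALIDATE_INPUTS once per node after the per-input loop, passing only the inputs named in its signature. A standalone illustration of that filtering (hypothetical node, not a real ComfyUI one):

import inspect

class ExampleNode:
    @classmethod
    def VALIDATE_INPUTS(cls, steps):
        return steps > 0 or f"steps must be positive, got {steps}"

validate_function_inputs = inspect.getfullargspec(ExampleNode.VALIDATE_INPUTS).args  # ['cls', 'steps']
input_data_all = {"steps": 20, "seed": 1234}
input_filtered = {k: v for k, v in input_data_all.items() if k in validate_function_inputs}
assert input_filtered == {"steps": 20}     # 'seed' is not in the signature, so it is skipped
print(ExampleNode.VALIDATE_INPUTS(**input_filtered))  # True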
@@ -188,8 +188,7 @@ def cached_filename_list_(folder_name):
     if folder_name not in filename_list_cache:
         return None
     out = filename_list_cache[folder_name]
-    if time.perf_counter() < (out[2] + 0.5):
-        return out
+
     for x in out[1]:
         time_modified = out[1][x]
         folder = x
@@ -23,9 +23,9 @@ def execute_prestartup_script():
         return False

     node_paths = folder_paths.get_folder_paths("custom_nodes")
+    node_prestartup_times = []
     for custom_node_path in node_paths:
         possible_modules = os.listdir(custom_node_path) if os.path.exists(custom_node_path) else []
-        node_prestartup_times = []

         for possible_module in possible_modules:
             module_path = os.path.join(custom_node_path, possible_module)

@@ -69,6 +69,10 @@ if args.cuda_device is not None:
     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
     print("Set cuda device to:", args.cuda_device)

+if args.deterministic:
+    if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
+        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
+
 from .. import utils
 import yaml

@@ -78,12 +82,12 @@ from .server import BinaryEventTypes
 from .. import model_management


-def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer):
+def prompt_worker(q, _server):
     e = execution.PromptExecutor(_server)
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
+    current_time = 0.0
     while True:
         timeout = None
         if need_gc:

@@ -94,11 +98,13 @@ def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer)
             item, item_id = queue_item
             execution_start_time = time.perf_counter()
             prompt_id = item[1]
+            _server.last_prompt_id = prompt_id
+
             e.execute(item[2], prompt_id, item[3], item[4])
             need_gc = True
             q.task_done(item_id, e.outputs_ui)
             if _server.client_id is not None:
-                _server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id)
+                _server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id)

             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time

@@ -119,7 +125,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):

 def hijack_progress(server):
     def hook(value, total, preview_image):
-        server.send_sync("progress", {"value": value, "max": total}, server.client_id)
+        model_management.throw_exception_if_processing_interrupted()
+        progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
+
+        server.send_sync("progress", progress, server.client_id)
         if preview_image is not None:
             server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)

@@ -204,7 +213,7 @@ def main():
         print(f"Setting output directory to: {output_dir}")
         folder_paths.set_output_directory(output_dir)

-    #These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
+    # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
     folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
     folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
     folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
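A sketch of the determinism knob --deterministic feeds into. The env var must be set before CUDA initializes; pairing it with torch.use_deterministic_algorithms is an assumption about plumbing elsewhere in the codebase, not shown in this diff:

import os
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')  # required for deterministic cuBLAS

import torch
torch.use_deterministic_algorithms(True, warn_only=True)  # assumed companion call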
@@ -734,7 +734,8 @@ class PromptServer():
         message = self.encode_bytes(event, data)

         if sid is None:
-            for ws in self.sockets.values():
+            sockets = list(self.sockets.values())
+            for ws in sockets:
                 await send_socket_catch_exception(ws.send_bytes, message)
         elif sid in self.sockets:
             await send_socket_catch_exception(self.sockets[sid].send_bytes, message)

@@ -743,7 +744,8 @@ class PromptServer():
         message = {"type": event, "data": data}

         if sid is None:
-            for ws in self.sockets.values():
+            sockets = list(self.sockets.values())
+            for ws in sockets:
                 await send_socket_catch_exception(ws.send_json, message)
         elif sid in self.sockets:
             await send_socket_catch_exception(self.sockets[sid].send_json, message)
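Snapshotting self.sockets.values() into a list guards the broadcast against the dict changing size while a send is awaited (e.g. a client disconnecting mid-broadcast). Minimal repro of the failure mode this avoids, illustrative only:

sockets = {"a": 1, "b": 2}
# for sid in sockets: sockets.pop(sid)
# -> RuntimeError: dictionary changed size during iteration
for sid in list(sockets):   # iterating a snapshot tolerates concurrent removal
    sockets.pop(sid)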
@@ -1,10 +1,13 @@
 import torch
 import math
 import os
+import contextlib
+
 from . import utils
 from . import model_management
 from . import model_detection
 from . import model_patcher
+from . import ops

 from .cldm import cldm
 from .t2i_adapter import adapter

@@ -34,13 +37,13 @@ class ControlBase:
         self.cond_hint = None
         self.strength = 1.0
         self.timestep_percent_range = (0.0, 1.0)
+        self.global_average_pooling = False
         self.timestep_range = None

         if device is None:
             device = model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
-        self.global_average_pooling = False

     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
         self.cond_hint_original = cond_hint

@@ -75,6 +78,7 @@ class ControlBase:
         c.cond_hint_original = self.cond_hint_original
         c.strength = self.strength
         c.timestep_percent_range = self.timestep_percent_range
+        c.global_average_pooling = self.global_average_pooling

     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:

@@ -127,12 +131,14 @@ class ControlBase:
         return out

 class ControlNet(ControlBase):
-    def __init__(self, control_model, global_average_pooling=False, device=None):
+    def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)
         self.control_model = control_model
-        self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
+        self.load_device = load_device
+        self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
+        self.manual_cast_dtype = manual_cast_dtype

     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None

@@ -146,28 +152,31 @@ class ControlNet(ControlBase):
         else:
             return None

+        dtype = self.control_model.dtype
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
+
         output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
                 self.cond_hint = None
-            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
+            self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

         context = cond['c_crossattn']
         y = cond.get('y', None)
         if y is not None:
-            y = y.to(self.control_model.dtype)
+            y = y.to(dtype)
         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

-        control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
         return self.control_merge(None, control, control_prev, output_dtype)

     def copy(self):
-        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
+        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
         self.copy_to(c)
         return c

@@ -198,10 +207,11 @@ class ControlLoraOps:
             self.bias = None

         def forward(self, input):
+            weight, bias = ops.cast_bias_weight(self, input)
             if self.up is not None:
-                return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
+                return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
             else:
-                return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
+                return torch.nn.functional.linear(input, weight, bias)

     class Conv2d(torch.nn.Module):
         def __init__(

@@ -237,16 +247,11 @@ class ControlLoraOps:


         def forward(self, input):
+            weight, bias = ops.cast_bias_weight(self, input)
             if self.up is not None:
-                return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
+                return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
             else:
-                return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
-
-    def conv_nd(self, dims, *args, **kwargs):
-        if dims == 2:
-            return self.Conv2d(*args, **kwargs)
-        else:
-            raise ValueError(f"unsupported dimensions: {dims}")
+                return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)


 class ControlLora(ControlNet):

@@ -260,17 +265,26 @@ class ControlLora(ControlNet):
         controlnet_config = model.model_config.unet_config.copy()
         controlnet_config.pop("out_channels")
         controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
-        controlnet_config["operations"] = ControlLoraOps()
-        self.control_model = cldm.ControlNet(**controlnet_config)
+        self.manual_cast_dtype = model.manual_cast_dtype
         dtype = model.get_dtype()
-        self.control_model.to(dtype)
+        if self.manual_cast_dtype is None:
+            class control_lora_ops(ControlLoraOps, ops.disable_weight_init):
+                pass
+        else:
+            class control_lora_ops(ControlLoraOps, ops.manual_cast):
+                pass
+            dtype = self.manual_cast_dtype
+
+        controlnet_config["operations"] = control_lora_ops
+        controlnet_config["dtype"] = dtype
+        self.control_model = cldm.ControlNet(**controlnet_config)
         self.control_model.to(model_management.get_torch_device())
         diffusion_model = model.diffusion_model
         sd = diffusion_model.state_dict()
         cm = self.control_model.state_dict()

         for k in sd:
-            weight = model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
+            weight = sd[k]
             try:
                 utils.set_attr(self.control_model, k, weight)
             except:

@@ -367,6 +381,10 @@ def load_controlnet(ckpt_path, model=None):
     if controlnet_config is None:
         unet_dtype = model_management.unet_dtype()
         controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
+        load_device = model_management.get_torch_device()
+        manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
+        if manual_cast_dtype is not None:
+            controlnet_config["operations"] = ops.manual_cast
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
     control_model = cldm.ControlNet(**controlnet_config)

@@ -395,14 +413,12 @@ def load_controlnet(ckpt_path, model=None):
     missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
     print(missing, unexpected)

-    control_model = control_model.to(unet_dtype)
-
     global_average_pooling = False
     filename = os.path.splitext(ckpt_path)[0]
     if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True

-    control = ControlNet(control_model, global_average_pooling=global_average_pooling)
+    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control

 class T2IAdapter(ControlBase):
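Both ControlLoraOps forwards now route through ops.cast_bias_weight. A minimal sketch of what that helper presumably does, inferred from its call sites here rather than copied from comfy/ops.py: it hands back the layer's weight and bias cast to the input's dtype and device, so stored weights can stay in a compact storage dtype.

import torch

def cast_bias_weight(module, input):
    # assumed semantics: match the incoming activation's dtype/device on the fly
    weight = module.weight.to(dtype=input.dtype, device=input.device)
    bias = None
    if module.bias is not None:
        bias = module.bias.to(dtype=input.dtype, device=input.device)
    return weight, bias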
@@ -33,3 +33,7 @@ class SDXL(LatentFormat):
             [-0.3112, -0.2359, -0.2076]
         ]
         self.taesd_decoder_name = "taesdxl_decoder"
+
+class SD_X4(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 0.08333
@@ -9,6 +9,7 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
 
 from ..util import instantiate_from_config, get_obj_from_str
 from ..modules.ema import LitEma
+from ... import ops
 
 class DiagonalGaussianRegularizer(torch.nn.Module):
     def __init__(self, sample: bool = True):
@@ -162,12 +163,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
             },
             **kwargs,
         )
-        self.quant_conv = torch.nn.Conv2d(
+        self.quant_conv = ops.disable_weight_init.Conv2d(
            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
            (1 + ddconfig["double_z"]) * embed_dim,
            1,
        )
-        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
         self.embed_dim = embed_dim
 
     def get_autoencoder_params(self) -> list:
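Note (illustrative sketch, not part of the diff): this file and several below swap plain torch.nn layers for ops.disable_weight_init variants. Default weight initialization is wasted work when a checkpoint is loaded immediately afterwards; a minimal equivalent, assuming the weights are always overwritten by load_state_dict:

```python
import torch

class Conv2d(torch.nn.Conv2d):
    def reset_parameters(self):
        return None  # skip random init; the checkpoint supplies the real weights

conv = Conv2d(3, 8, kernel_size=3)  # allocated but intentionally uninitialized
conv.load_state_dict({"weight": torch.zeros(8, 3, 3, 3), "bias": torch.zeros(8)})
```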
@@ -18,6 +18,7 @@ if model_management.xformers_enabled():
 
 from ...cli_args import args
 from ... import ops
+ops = ops.disable_weight_init
 
 # CrossAttn precision handling
 if args.dont_upcast_attention:
@@ -82,16 +83,6 @@ class FeedForward(nn.Module):
     def forward(self, x):
         return self.net(x)
 
-
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
@@ -112,19 +103,20 @@ def attention_basic(q, k, v, heads, mask=None):
 
     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
-        with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale
 
     del q, k
 
     if exists(mask):
-        mask = rearrange(mask, 'b ... -> b (...)')
-        max_neg_value = -torch.finfo(sim.dtype).max
-        mask = repeat(mask, 'b j -> (b h) () j', h=h)
-        sim.masked_fill_(~mask, max_neg_value)
+        if mask.dtype == torch.bool:
+            mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+        else:
+            sim += mask
 
     # attention, what we cannot get enough of
     sim = sim.softmax(dim=-1)
@@ -349,6 +341,18 @@ else:
     if model_management.pytorch_attention_enabled():
         optimized_attention_masked = attention_pytorch
 
+def optimized_attention_for_device(device, mask=False):
+    if device == torch.device("cpu"): #TODO
+        if model_management.pytorch_attention_enabled():
+            return attention_pytorch
+        else:
+            return attention_basic
+    if mask:
+        return optimized_attention_masked
+
+    return optimized_attention
+
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
         super().__init__()
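Note (illustrative sketch, not part of the diff): the new optimized_attention_for_device() lets callers pick an attention kernel per device instead of relying on one global choice. A self-contained sketch of the dispatch pattern, with local stand-in names:

```python
import torch
import torch.nn.functional as F

def attention_basic(q, k, v, mask=None):
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if mask is not None:
        scores = scores + mask
    return scores.softmax(dim=-1) @ v

def attention_for_device(device, mask=False):
    # CPU gets a known-good fallback; other devices keep the global kernel
    if device == torch.device("cpu") and hasattr(F, "scaled_dot_product_attention"):
        return lambda q, k, v, mask=None: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return attention_basic

q = k = v = torch.randn(1, 4, 16, 8)
out = attention_for_device(q.device)(q, k, v)
assert out.shape == (1, 4, 16, 8)
```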
@@ -393,7 +397,7 @@ class BasicTransformerBlock(nn.Module):
         self.is_res = inner_dim == dim
 
         if self.ff_in:
-            self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
+            self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
             self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
 
         self.disable_self_attn = disable_self_attn
@@ -413,10 +417,10 @@ class BasicTransformerBlock(nn.Module):
 
         self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
                             heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
-        self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
 
-        self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
-        self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
         self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head
@@ -558,7 +562,7 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim] * depth
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels, dtype=dtype, device=device)
+        self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
         if not use_linear:
             self.proj_in = operations.Conv2d(in_channels,
                                              inner_dim,
@@ -8,6 +8,7 @@ from typing import Optional, Any
 
 from .... import model_management
 from .... import ops
+ops = ops.disable_weight_init
 
 if model_management.xformers_enabled_vae():
     import xformers
@@ -40,7 +41,7 @@ def nonlinearity(x):
 
 
 def Normalize(in_channels, num_groups=32):
-    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+    return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
 
 
 class Upsample(nn.Module):
@@ -12,13 +12,13 @@ from .util import (
     checkpoint,
     avg_pool_nd,
     zero_module,
-    normalization,
     timestep_embedding,
     AlphaBlender,
 )
 from ..attention import SpatialTransformer, SpatialVideoTransformer, default
 from ...util import exists
 from .... import ops
+ops = ops.disable_weight_init
 
 class TimestepBlock(nn.Module):
     """
@@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
         padding = kernel_size // 2
 
         self.in_layers = nn.Sequential(
-            nn.GroupNorm(32, channels, dtype=dtype, device=device),
+            operations.GroupNorm(32, channels, dtype=dtype, device=device),
             nn.SiLU(),
             operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
         )
@@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
             ),
         )
         self.out_layers = nn.Sequential(
-            nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
+            operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
             nn.SiLU(),
             nn.Dropout(p=dropout),
-            zero_module(
-                operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
-            ),
+            operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
+            ,
         )
 
         if self.out_channels == channels:
@@ -438,9 +437,6 @@ class UNetModel(nn.Module):
         operations=ops,
     ):
         super().__init__()
-        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
-        if use_spatial_transformer:
-            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
 
         if context_dim is not None:
             assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
@@ -457,7 +453,6 @@ class UNetModel(nn.Module):
         if num_head_channels == -1:
             assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
 
-        self.image_size = image_size
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
@@ -503,7 +498,7 @@ class UNetModel(nn.Module):
 
         if self.num_classes is not None:
             if isinstance(self.num_classes, int):
-                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
             elif self.num_classes == "continuous":
                 print("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
@@ -810,13 +805,13 @@ class UNetModel(nn.Module):
             self._feature_size += ch
 
         self.out = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+            operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
             nn.SiLU(),
             zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-                nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+                operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
                 operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
                 #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
             )
@@ -842,14 +837,14 @@ class UNetModel(nn.Module):
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
         hs = []
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
         emb = self.time_embed(t_emb)
 
         if self.num_classes is not None:
             assert y.shape[0] == x.shape[0]
             emb = emb + self.label_emb(y)
 
-        h = x.type(self.dtype)
+        h = x
         for id, module in enumerate(self.input_blocks):
             transformer_options["block"] = ("input", id)
             h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
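Note (illustrative sketch, not part of the diff): the forward() changes above stop force-casting the input and timestep embedding to the weight dtype (self.dtype); both now follow x.dtype, so weights stored in a low-precision format can be driven by an fp16 input while manual-cast ops handle the per-layer conversion. The invariant, in miniature:

```python
import torch

compute_dtype = torch.float16           # the dtype the input arrives in
x = torch.randn(1, 4, 8, 8, dtype=compute_dtype)
t_emb = torch.randn(1, 32).to(x.dtype)  # was .to(self.dtype) before this patch
h = x                                   # was x.type(self.dtype) before this patch
assert h.dtype == t_emb.dtype == compute_dtype
```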
@@ -41,10 +41,14 @@ class AbstractLowScaleModel(nn.Module):
         self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
         self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
 
-    def q_sample(self, x_start, t, noise=None):
-        noise = default(noise, lambda: torch.randn_like(x_start))
-        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
-                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+    def q_sample(self, x_start, t, noise=None, seed=None):
+        if noise is None:
+            if seed is None:
+                noise = torch.randn_like(x_start)
+            else:
+                noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
+        return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
 
     def forward(self, x):
         return x, None
@@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
         super().__init__(noise_schedule_config=noise_schedule_config)
         self.max_noise_level = max_noise_level
 
-    def forward(self, x, noise_level=None):
+    def forward(self, x, noise_level=None, seed=None):
         if noise_level is None:
             noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
         else:
             assert isinstance(noise_level, torch.Tensor)
-        z = self.q_sample(x, noise_level)
+        z = self.q_sample(x, noise_level, seed=seed)
         return z, noise_level
 
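Note (illustrative sketch, not part of the diff): the new seed argument to q_sample() makes the augmentation noise reproducible regardless of global RNG state, using the same construction as the hunk above:

```python
import torch

def seeded_noise_like(x, seed):
    gen = torch.manual_seed(seed)  # re-seeds and returns the default Generator
    return torch.randn(x.size(), dtype=x.dtype, layout=x.layout, generator=gen).to(x.device)

x = torch.zeros(2, 3)
assert torch.equal(seeded_noise_like(x, 42), seeded_noise_like(x, 42))
```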
@@ -16,7 +16,6 @@ import numpy as np
 from einops import repeat, rearrange
 
 from ...util import instantiate_from_config
-from .... import ops
 
 class AlphaBlender(nn.Module):
     strategies = ["learned", "fixed", "learned_with_images"]
@@ -52,9 +51,9 @@ class AlphaBlender(nn.Module):
         if self.merge_strategy == "fixed":
             # make shape compatible
             # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
-            alpha = self.mix_factor
+            alpha = self.mix_factor.to(image_only_indicator.device)
         elif self.merge_strategy == "learned":
-            alpha = torch.sigmoid(self.mix_factor)
+            alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
             # make shape compatible
             # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
         elif self.merge_strategy == "learned_with_images":
@@ -62,7 +61,7 @@ class AlphaBlender(nn.Module):
             alpha = torch.where(
                 image_only_indicator.bool(),
                 torch.ones(1, 1, device=image_only_indicator.device),
-                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
+                rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
             )
             alpha = rearrange(alpha, self.rearrange_pattern)
             # make shape compatible
@@ -273,46 +272,6 @@ def mean_flat(tensor):
     return tensor.mean(dim=list(range(1, len(tensor.shape))))
 
 
-def normalization(channels, dtype=None):
-    """
-    Make a standard normalization layer.
-    :param channels: number of input channels.
-    :return: an nn.Module for normalization.
-    """
-    return GroupNorm32(32, channels, dtype=dtype)
-
-
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
-    def forward(self, x):
-        return x * torch.sigmoid(x)
-
-
-class GroupNorm32(nn.GroupNorm):
-    def forward(self, x):
-        return super().forward(x.float()).type(x.dtype)
-
-
-def conv_nd(dims, *args, **kwargs):
-    """
-    Create a 1D, 2D, or 3D convolution module.
-    """
-    if dims == 1:
-        return nn.Conv1d(*args, **kwargs)
-    elif dims == 2:
-        return ops.Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return nn.Conv3d(*args, **kwargs)
-    raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def linear(*args, **kwargs):
-    """
-    Create a linear module.
-    """
-    return ops.Linear(*args, **kwargs)
-
-
 def avg_pool_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D average pooling module.
@@ -15,12 +15,12 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
 
     def scale(self, x):
         # re-normalize to centered mean and unit variance
-        x = (x - self.data_mean) * 1. / self.data_std
+        x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
         return x
 
     def unscale(self, x):
         # back to original data stats
-        x = (x * self.data_std) + self.data_mean
+        x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
         return x
 
     def forward(self, x, noise_level=None):
@@ -5,6 +5,7 @@ import torch
 from einops import rearrange, repeat
 
 from ... import ops
+ops = ops.disable_weight_init
 
 from .diffusionmodules.model import (
     AttnBlock,
@@ -81,14 +82,14 @@ class VideoResBlock(ResnetBlock):
 
         x = self.time_stack(x, temb)
 
-        alpha = self.get_alpha(bs=b // timesteps)
+        alpha = self.get_alpha(bs=b // timesteps).to(x.device)
         x = alpha * x + (1.0 - alpha) * x_mix
 
         x = rearrange(x, "b c t h w -> (b t) c h w")
         return x
 
 
-class AE3DConv(torch.nn.Conv2d):
+class AE3DConv(ops.Conv2d):
     def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
         super().__init__(in_channels, out_channels, *args, **kwargs)
         if isinstance(video_kernel_size, Iterable):
@@ -96,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
         else:
             padding = int(video_kernel_size // 2)
 
-        self.time_mix_conv = torch.nn.Conv3d(
+        self.time_mix_conv = ops.Conv3d(
             in_channels=out_channels,
             out_channels=out_channels,
             kernel_size=video_kernel_size,
@@ -166,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
             emb = emb[:, None, :]
             x_mix = x_mix + emb
 
-        alpha = self.get_alpha()
+        alpha = self.get_alpha().to(x.device)
         x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
         x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
 
@@ -43,7 +43,7 @@ def load_lora(lora, to_load):
         if mid_name is not None and mid_name in lora.keys():
             mid = lora[mid_name]
             loaded_keys.add(mid_name)
-        patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
+        patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
         loaded_keys.add(A_name)
         loaded_keys.add(B_name)
 
@@ -64,7 +64,7 @@ def load_lora(lora, to_load):
             loaded_keys.add(hada_t1_name)
             loaded_keys.add(hada_t2_name)
 
-        patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
+        patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
         loaded_keys.add(hada_w1_a_name)
         loaded_keys.add(hada_w1_b_name)
         loaded_keys.add(hada_w2_a_name)
@@ -116,8 +116,19 @@ def load_lora(lora, to_load):
         loaded_keys.add(lokr_t2_name)
 
     if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-        patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
+        patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
+
+    #glora
+    a1_name = "{}.a1.weight".format(x)
+    a2_name = "{}.a2.weight".format(x)
+    b1_name = "{}.b1.weight".format(x)
+    b2_name = "{}.b2.weight".format(x)
+    if a1_name in lora:
+        patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+        loaded_keys.add(a1_name)
+        loaded_keys.add(a2_name)
+        loaded_keys.add(b1_name)
+        loaded_keys.add(b2_name)
 
     w_norm_name = "{}.w_norm".format(x)
     b_norm_name = "{}.b_norm".format(x)
@@ -126,21 +137,21 @@ def load_lora(lora, to_load):
 
     if w_norm is not None:
         loaded_keys.add(w_norm_name)
-        patch_dict[to_load[x]] = (w_norm,)
+        patch_dict[to_load[x]] = ("diff", (w_norm,))
         if b_norm is not None:
             loaded_keys.add(b_norm_name)
-            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
+            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
 
     diff_name = "{}.diff".format(x)
     diff_weight = lora.get(diff_name, None)
     if diff_weight is not None:
-        patch_dict[to_load[x]] = (diff_weight,)
+        patch_dict[to_load[x]] = ("diff", (diff_weight,))
         loaded_keys.add(diff_name)
 
     diff_bias_name = "{}.diff_b".format(x)
    diff_bias = lora.get(diff_bias_name, None)
     if diff_bias is not None:
-        patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
+        patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
         loaded_keys.add(diff_bias_name)
 
     for x in lora.keys():
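Note (illustrative sketch, not part of the diff): every patch_dict entry now carries an explicit type tag ("lora", "loha", "lokr", "glora", "diff") instead of being told apart by tuple length, so weight patching becomes a plain dispatch. A toy version with made-up tensors:

```python
import torch

A = torch.randn(16, 4)   # "up" matrix from the lora file
B = torch.randn(4, 16)   # "down" matrix
patch_dict = {"layer.weight": ("lora", (A, B, 0.5, None))}

weight = torch.zeros(16, 16)
for key, (patch_type, data) in patch_dict.items():
    if patch_type == "lora":
        up, down, alpha, _mid = data
        weight += alpha * (up @ down)   # classic low-rank update
    elif patch_type == "diff":
        weight += data[0]               # plain additive delta
```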
@@ -1,10 +1,12 @@
 import torch
-from .ldm.modules.diffusionmodules.openaimodel import UNetModel
+from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
-from .ldm.modules.diffusionmodules.openaimodel import Timestep
+from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from . import model_management
 from . import conds
+from . import ops
 from enum import Enum
+import contextlib
 from . import utils
 
 class ModelType(Enum):
@@ -13,7 +15,7 @@ class ModelType(Enum):
     V_PREDICTION_EDM = 3
 
 
-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
+from .model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
 
 
 def model_sampling(model_config, model_type):
@@ -40,9 +42,14 @@ class BaseModel(torch.nn.Module):
         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
+        self.manual_cast_dtype = model_config.manual_cast_dtype
 
         if not unet_config.get("disable_unet_model_creation", False):
-            self.diffusion_model = UNetModel(**unet_config, device=device)
+            if self.manual_cast_dtype is not None:
+                operations = ops.manual_cast
+            else:
+                operations = ops.disable_weight_init
+            self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
 
@@ -61,15 +68,21 @@ class BaseModel(torch.nn.Module):
 
         context = c_crossattn
         dtype = self.get_dtype()
+
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
+
         xc = xc.to(dtype)
         t = self.model_sampling.timestep(t).float()
         context = context.to(dtype)
         extra_conds = {}
         for o in kwargs:
             extra = kwargs[o]
-            if hasattr(extra, "to"):
-                extra = extra.to(dtype)
+            if hasattr(extra, "dtype"):
+                if extra.dtype != torch.int and extra.dtype != torch.long:
+                    extra = extra.to(dtype)
             extra_conds[o] = extra
 
         model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
         return self.model_sampling.calculate_denoised(sigma, model_output, x)
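Note (illustrative sketch, not part of the diff): the dtype guard above exists because casting every tensor cond to the unet dtype would silently convert integer conditioning (indices, frame counts) to float. Only floating tensors are converted now:

```python
import torch

def cast_extra(extra, dtype):
    if hasattr(extra, "dtype"):
        if extra.dtype != torch.int and extra.dtype != torch.long:
            extra = extra.to(dtype)
    return extra

assert cast_extra(torch.zeros(2, dtype=torch.long), torch.float16).dtype == torch.long
assert cast_extra(torch.zeros(2), torch.float16).dtype == torch.float16
```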
@@ -117,6 +130,10 @@ class BaseModel(torch.nn.Module):
         adm = self.encode_adm(**kwargs)
         if adm is not None:
             out['y'] = conds.CONDRegular(adm)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
         return out
 
     def load_model_weights(self, sd, unet_prefix=""):
@@ -144,11 +161,7 @@ class BaseModel(torch.nn.Module):
 
     def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
         clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
-        unet_sd = self.diffusion_model.state_dict()
-        unet_state_dict = {}
-        for k in unet_sd:
-            unet_state_dict[k] = model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
-
+        unet_state_dict = self.diffusion_model.state_dict()
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
         vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
         if self.get_dtype() == torch.float16:
@@ -165,9 +178,12 @@ class BaseModel(torch.nn.Module):
 
     def memory_required(self, input_shape):
         if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
+            dtype = self.get_dtype()
+            if self.manual_cast_dtype is not None:
+                dtype = self.manual_cast_dtype
             #TODO: this needs to be tweaked
             area = input_shape[0] * input_shape[2] * input_shape[3]
-            return (area * model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
+            return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
         else:
             #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
             area = input_shape[0] * input_shape[2] * input_shape[3]
@@ -307,9 +323,75 @@ class SVD_img2vid(BaseModel):
 
         out['c_concat'] = conds.CONDNoiseShape(latent_image)
 
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
+
         if "time_conditioning" in kwargs:
             out["time_context"] = conds.CONDCrossAttn(kwargs["time_conditioning"])
 
         out['image_only_indicator'] = conds.CONDConstant(torch.zeros((1,), device=device))
         out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
         return out
+
+class Stable_Zero123(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
+        super().__init__(model_config, model_type, device=device)
+        self.cc_projection = ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
+        self.cc_projection.weight.copy_(cc_projection_weight)
+        self.cc_projection.bias.copy_(cc_projection_bias)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        latent_image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+
+        if latent_image is None:
+            latent_image = torch.zeros_like(noise)
+
+        if latent_image.shape[1:] != noise.shape[1:]:
+            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
+
+        out['c_concat'] = conds.CONDNoiseShape(latent_image)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            if cross_attn.shape[-1] != 768:
+                cross_attn = self.cc_projection(cross_attn)
+            out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
+        return out
+
+class SD_X4Upscaler(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
+        super().__init__(model_config, model_type, device=device)
+        self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+
+        image = kwargs.get("concat_image", None)
+        noise = kwargs.get("noise", None)
+        noise_augment = kwargs.get("noise_augmentation", 0.0)
+        device = kwargs["device"]
+        seed = kwargs["seed"] - 10
+
+        noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
+
+        if image is None:
+            image = torch.zeros_like(noise)[:,:3]
+
+        if image.shape[1:] != noise.shape[1:]:
+            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+        noise_level = torch.tensor([noise_level], device=device)
+        if noise_augment > 0:
+            image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)
+
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+
+        out['c_concat'] = conds.CONDNoiseShape(image)
+        out['y'] = conds.CONDRegular(noise_level)
+        return out
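Note (illustrative sketch, not part of the diff): the new SD_X4Upscaler maps the user-facing 0..1 noise_augmentation knob onto the discrete noise level the upscaler model was trained with (max_noise_level=350) before running the seeded noise augmentor, as in extra_conds() above:

```python
max_noise_level = 350
for noise_augment in (0.0, 0.1, 1.0):
    print(noise_augment, round(max_noise_level * noise_augment))  # 0, 35, 350
```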
@@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
     unet_config = {
         "use_checkpoint": False,
         "image_size": 32,
-        "out_channels": 4,
         "use_spatial_transformer": True,
         "legacy": False
     }
@@ -50,6 +49,12 @@ def detect_unet_config(state_dict, key_prefix, dtype):
     model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
     in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
 
+    out_key = '{}out.2.weight'.format(key_prefix)
+    if out_key in state_dict:
+        out_channels = state_dict[out_key].shape[0]
+    else:
+        out_channels = 4
+
     num_res_blocks = []
     channel_mult = []
     attention_resolutions = []
@@ -122,6 +127,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
     transformer_depth_middle = -1
 
     unet_config["in_channels"] = in_channels
+    unet_config["out_channels"] = out_channels
     unet_config["model_channels"] = model_channels
     unet_config["num_res_blocks"] = num_res_blocks
     unet_config["transformer_depth"] = transformer_depth
@@ -289,7 +295,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
                  'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                  'use_temporal_attention': False, 'use_temporal_resblock': False}
 
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
+    Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+                    'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
+                    'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
+                    'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
+                    'use_temporal_attention': False, 'use_temporal_resblock': False}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega]
 
     for unet_config in supported_models:
         matches = True
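Note (illustrative sketch, not part of the diff): out_channels is now detected from the final conv weight instead of being hard-coded to 4, which is what lets checkpoints with a different output width (like the x4 upscaler) be identified. With a toy state dict:

```python
import torch

key_prefix = "model.diffusion_model."
state_dict = {key_prefix + "out.2.weight": torch.zeros(4, 320, 3, 3)}

out_key = '{}out.2.weight'.format(key_prefix)
out_channels = state_dict[out_key].shape[0] if out_key in state_dict else 4
print(out_channels)  # 4 for a standard SD UNet
```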
@@ -2,6 +2,7 @@ import psutil
 from enum import Enum
 from .cli_args import args
 from . import utils
+
 import torch
 import sys
 
@@ -28,6 +29,10 @@ total_vram = 0
 lowvram_available = True
 xpu_available = False
 
+if args.deterministic:
+    print("Using deterministic algorithms for pytorch")
+    torch.use_deterministic_algorithms(True, warn_only=True)
+
 directml_enabled = False
 if args.directml is not None:
     import torch_directml
@@ -182,6 +187,9 @@ except:
 if is_intel_xpu():
     VAE_DTYPE = torch.bfloat16
 
+if args.cpu_vae:
+    VAE_DTYPE = torch.float32
+
 if args.fp16_vae:
     VAE_DTYPE = torch.float16
 elif args.bf16_vae:
@@ -214,15 +222,8 @@ if args.force_fp16 or cpu_state == CPUState.MPS:
     FORCE_FP16 = True
 
 if lowvram_available:
-    try:
-        import accelerate
-        if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
-            vram_state = set_vram_to
-    except Exception as e:
-        import traceback
-        print(traceback.format_exc())
-        print("ERROR: LOW VRAM MODE NEEDS accelerate.")
-        lowvram_available = False
+    if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+        vram_state = set_vram_to
 
 
 if cpu_state != CPUState.GPU:
@@ -262,6 +263,14 @@ print("VAE dtype:", VAE_DTYPE)
 
 current_loaded_models = []
 
+def module_size(module):
+    module_mem = 0
+    sd = module.state_dict()
+    for k in sd:
+        t = sd[k]
+        module_mem += t.nelement() * t.element_size()
+    return module_mem
+
 class LoadedModel:
     def __init__(self, model):
         self.model = model
@@ -294,8 +303,20 @@ class LoadedModel:
 
         if lowvram_model_memory > 0:
             print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
-            device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
-            accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
+            mem_counter = 0
+            for m in self.real_model.modules():
+                if hasattr(m, "comfy_cast_weights"):
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    module_mem = module_size(m)
+                    if mem_counter + module_mem < lowvram_model_memory:
+                        m.to(self.device)
+                        mem_counter += module_mem
+                elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
+                    m.to(self.device)
+                    mem_counter += module_size(m)
+                    print("lowvram: loaded module regularly", m)
+
             self.model_accelerated = True
 
         if is_intel_xpu() and not args.disable_ipex_optimize:
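Note (illustrative sketch, not part of the diff): lowvram mode no longer needs accelerate. Modules are moved to the GPU greedily until the memory budget runs out, and anything left behind relies on manual weight casting. A runnable miniature of the greedy pass, with a CPU device standing in for the GPU:

```python
import torch

def module_size(module):
    return sum(t.nelement() * t.element_size() for t in module.state_dict().values())

model = torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(4)])
budget = int(2.5 * module_size(model[0]))  # pretend there is VRAM for two layers
device = torch.device("cpu")               # stand-in for the real load device

mem = 0
for m in model.modules():
    if isinstance(m, torch.nn.Linear):     # stand-in for hasattr(m, "comfy_cast_weights")
        size = module_size(m)
        if mem + size < budget:
            m.to(device)                   # fits: keep it resident
            mem += size
print("resident bytes:", mem)              # two layers' worth
```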
@@ -305,7 +326,11 @@ class LoadedModel:
 
     def model_unload(self):
         if self.model_accelerated:
-            accelerate.hooks.remove_hook_from_submodules(self.real_model)
+            for m in self.real_model.modules():
+                if hasattr(m, "prev_comfy_cast_weights"):
+                    m.comfy_cast_weights = m.prev_comfy_cast_weights
+                    del m.prev_comfy_cast_weights
+
             self.model_accelerated = False
 
         self.model.unpatch_model(self.model.offload_device)
@@ -398,14 +423,14 @@ def load_models_gpu(models, memory_required=0):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
+            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
             if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
                 vram_set_state = VRAMState.LOW_VRAM
             else:
                 lowvram_model_memory = 0
 
         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 256 * 1024 * 1024
+            lowvram_model_memory = 64 * 1024 * 1024
 
         cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
         current_loaded_models.insert(0, loaded_model)
@@ -430,6 +455,13 @@ def dtype_size(dtype):
     dtype_size = 4
     if dtype == torch.float16 or dtype == torch.bfloat16:
         dtype_size = 2
+    elif dtype == torch.float32:
+        dtype_size = 4
+    else:
+        try:
+            dtype_size = dtype.itemsize
+        except: #Old pytorch doesn't have .itemsize
+            pass
     return dtype_size
 
 def unet_offload_device():
@@ -459,10 +491,30 @@ def unet_inital_load_device(parameters, dtype):
 def unet_dtype(device=None, model_params=0):
     if args.bf16_unet:
         return torch.bfloat16
+    if args.fp16_unet:
+        return torch.float16
+    if args.fp8_e4m3fn_unet:
+        return torch.float8_e4m3fn
+    if args.fp8_e5m2_unet:
+        return torch.float8_e5m2
     if should_use_fp16(device=device, model_params=model_params):
         return torch.float16
     return torch.float32
 
+# None means no manual cast
+def unet_manual_cast(weight_dtype, inference_device):
+    if weight_dtype == torch.float32:
+        return None
+
+    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
+    if fp16_supported and weight_dtype == torch.float16:
+        return None
+
+    if fp16_supported:
+        return torch.float16
+    else:
+        return torch.float32
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
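Note (illustrative sketch, not part of the diff): unet_manual_cast() is the decision table behind all the manual-cast plumbing above. None means "run the weights as stored"; a dtype means "keep weights in weight_dtype but compute in this dtype, casting per op". A standalone copy, assuming a PyTorch build with fp8 dtypes:

```python
import torch

def manual_cast_decision(weight_dtype, fp16_supported):
    if weight_dtype == torch.float32:
        return None
    if fp16_supported and weight_dtype == torch.float16:
        return None
    return torch.float16 if fp16_supported else torch.float32

print(manual_cast_decision(torch.float8_e4m3fn, True))   # torch.float16
print(manual_cast_decision(torch.float16, False))        # torch.float32
```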
@@ -492,12 +544,23 @@ def text_encoder_dtype(device=None):
     elif args.fp32_text_enc:
         return torch.float32
 
+    if is_device_cpu(device):
+        return torch.float16
+
     if should_use_fp16(device, prioritize_performance=False):
         return torch.float16
     else:
         return torch.float32
 
+def intermediate_device():
+    if args.gpu_only:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def vae_device():
+    if args.cpu_vae:
+        return torch.device("cpu")
     return get_torch_device()
 
 def vae_offload_device():
@@ -515,6 +578,22 @@ def get_autocast_device(dev):
         return dev.type
     return "cuda"
 
+def supports_dtype(device, dtype): #TODO
+    if dtype == torch.float32:
+        return True
+    if is_device_cpu(device):
+        return False
+    if dtype == torch.float16:
+        return True
+    if dtype == torch.bfloat16:
+        return True
+    return False
+
+def device_supports_non_blocking(device):
+    if is_device_mps(device):
+        return False #pytorch bug? mps doesn't support non blocking
+    return True
+
 def cast_to_device(tensor, device, dtype, copy=False):
     device_supports_cast = False
     if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@@ -525,15 +604,17 @@ def cast_to_device(tensor, device, dtype, copy=False):
     elif is_intel_xpu():
         device_supports_cast = True
 
+    non_blocking = device_supports_non_blocking(device)
+
     if device_supports_cast:
         if copy:
             if tensor.device == device:
-                return tensor.to(dtype, copy=copy)
-            return tensor.to(device, copy=copy).to(dtype)
+                return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
+            return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
         else:
-            return tensor.to(device).to(dtype)
+            return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
     else:
-        return tensor.to(dtype).to(device, copy=copy)
+        return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
 
 def xformers_enabled():
     global directml_enabled
@@ -687,11 +768,11 @@ def soft_empty_cache(force=False):
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
 
-def resolve_lowvram_weight(weight, model, key):
-    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
-        key_split = key.split('.')              # I have no idea why they don't just leave the weight there instead of using the meta device.
-        op = utils.get_attr(model, '.'.join(key_split[:-1]))
-        weight = op._hf_hook.weights_map[key_split[-1]]
+def unload_all_models():
+    free_memory(1e30, get_torch_device())
+
+def resolve_lowvram_weight(weight, model, key): #TODO: remove
     return weight
 
 #TODO: might be cleaner to put this somewhere else
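Note (illustrative sketch, not part of the diff): the non_blocking flag threaded into cast_to_device() lets host-to-device copies of pinned memory overlap with compute; MPS is excluded above because non-blocking transfers misbehave there. Typical usage of the underlying PyTorch mechanism:

```python
import torch

t = torch.randn(1024, 1024)
if torch.cuda.is_available():
    t = t.pin_memory()                   # pinned host memory enables async copy
    g = t.to("cuda", non_blocking=True)  # returns immediately; copy overlaps
    torch.cuda.synchronize()             # wait before depending on the result
```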
@ -28,13 +28,9 @@ class ModelPatcher:
|
|||||||
if self.size > 0:
|
if self.size > 0:
|
||||||
return self.size
|
return self.size
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model.state_dict()
|
||||||
size = 0
|
self.size = model_management.module_size(self.model)
|
||||||
for k in model_sd:
|
|
||||||
t = model_sd[k]
|
|
||||||
size += t.nelement() * t.element_size()
|
|
||||||
self.size = size
|
|
||||||
self.model_keys = set(model_sd.keys())
|
self.model_keys = set(model_sd.keys())
|
||||||
return size
|
return self.size
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
||||||
@@ -55,11 +51,18 @@ class ModelPatcher:
     def memory_required(self, input_shape):
         return self.model.memory_required(input_shape=input_shape)

-    def set_model_sampler_cfg_function(self, sampler_cfg_function):
+    def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
         if len(inspect.signature(sampler_cfg_function).parameters) == 3:
             self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
         else:
             self.model_options["sampler_cfg_function"] = sampler_cfg_function
+        if disable_cfg1_optimization:
+            self.model_options["disable_cfg1_optimization"] = True
+
+    def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
+        self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
+        if disable_cfg1_optimization:
+            self.model_options["disable_cfg1_optimization"] = True

     def set_model_unet_function_wrapper(self, unet_wrapper_function):
         self.model_options["model_function_wrapper"] = unet_wrapper_function
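Illustrative sketch (not part of the patch): with the two setters above, a custom node can hook the guided result after CFG is applied. Only the `set_model_sampler_post_cfg_function` API and the `args` keys come from this diff; the function and blend factor are hypothetical.

    # hypothetical custom-node snippet; `model` is assumed to be a ModelPatcher clone
    def rescale_post_cfg(args):
        cond_pred = args["cond_denoised"]   # denoised prediction with conditioning
        cfg_result = args["denoised"]       # uncond + (cond - uncond) * cond_scale
        return cfg_result * 0.9 + cond_pred * 0.1  # blend 10% of cond back in

    m = model.clone()
    # force the uncond pass even at cfg == 1.0 so "uncond_denoised" is populated
    m.set_model_sampler_post_cfg_function(rescale_post_cfg, disable_cfg1_optimization=True)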
@@ -70,13 +73,17 @@ class ModelPatcher:
             to["patches"] = {}
         to["patches"][name] = to["patches"].get(name, []) + [patch]

-    def set_model_patch_replace(self, patch, name, block_name, number):
+    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
         to = self.model_options["transformer_options"]
         if "patches_replace" not in to:
             to["patches_replace"] = {}
         if name not in to["patches_replace"]:
             to["patches_replace"][name] = {}
-        to["patches_replace"][name][(block_name, number)] = patch
+        if transformer_index is not None:
+            block = (block_name, number, transformer_index)
+        else:
+            block = (block_name, number)
+        to["patches_replace"][name][block] = patch

     def set_model_attn1_patch(self, patch):
         self.set_model_patch(patch, "attn1_patch")
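Illustrative sketch (hypothetical names): the new optional `transformer_index` narrows an attention replacement to a single transformer block inside a unet block instead of every block with that `(block_name, number)`.

    m = model.clone()
    # my_attn1_patch is a placeholder for an attention callable of the caller's choosing;
    # this replaces attn1 only in transformer block 1 of output block ("output", 4)
    m.set_model_attn1_replace(my_attn1_patch, "output", 4, transformer_index=1)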
@@ -84,11 +91,11 @@ class ModelPatcher:
     def set_model_attn2_patch(self, patch):
         self.set_model_patch(patch, "attn2_patch")

-    def set_model_attn1_replace(self, patch, block_name, number):
-        self.set_model_patch_replace(patch, "attn1", block_name, number)
+    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
+        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)

-    def set_model_attn2_replace(self, patch, block_name, number):
-        self.set_model_patch_replace(patch, "attn2", block_name, number)
+    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
+        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)

     def set_model_attn1_output_patch(self, patch):
         self.set_model_patch(patch, "attn1_output_patch")
@@ -167,40 +174,41 @@ class ModelPatcher:
                 sd.pop(k)
         return sd

-    def patch_model(self, device_to=None):
+    def patch_model(self, device_to=None, patch_weights=True):
         for k in self.object_patches:
             old = getattr(self.model, k)
             if k not in self.object_patches_backup:
                 self.object_patches_backup[k] = old
             setattr(self.model, k, self.object_patches[k])

-        model_sd = self.model_state_dict()
-        for key in self.patches:
-            if key not in model_sd:
-                print("could not patch. key doesn't exist in model:", key)
-                continue
-
-            weight = model_sd[key]
-
-            inplace_update = self.weight_inplace_update
-
-            if key not in self.backup:
-                self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
-
-            if device_to is not None:
-                temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
-            else:
-                temp_weight = weight.to(torch.float32, copy=True)
-            out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
-            if inplace_update:
-                utils.copy_to_param(self.model, key, out_weight)
-            else:
-                utils.set_attr(self.model, key, out_weight)
-            del temp_weight
-
-        if device_to is not None:
-            self.model.to(device_to)
-            self.current_device = device_to
+        if patch_weights:
+            model_sd = self.model_state_dict()
+            for key in self.patches:
+                if key not in model_sd:
+                    print("could not patch. key doesn't exist in model:", key)
+                    continue
+
+                weight = model_sd[key]
+
+                inplace_update = self.weight_inplace_update
+
+                if key not in self.backup:
+                    self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+
+                if device_to is not None:
+                    temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+                else:
+                    temp_weight = weight.to(torch.float32, copy=True)
+                out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
+                if inplace_update:
+                    utils.copy_to_param(self.model, key, out_weight)
+                else:
+                    utils.set_attr(self.model, key, out_weight)
+                del temp_weight
+
+            if device_to is not None:
+                self.model.to(device_to)
+                self.current_device = device_to

         return self.model
@@ -217,13 +225,19 @@ class ModelPatcher:
                 v = (self.calculate_weight(v[1:], v[0].clone(), key), )

             if len(v) == 1:
+                patch_type = "diff"
+            elif len(v) == 2:
+                patch_type = v[0]
+                v = v[1]
+
+            if patch_type == "diff":
                 w1 = v[0]
                 if alpha != 0.0:
                     if w1.shape != weight.shape:
                         print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                     else:
                         weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
-            elif len(v) == 4: #lora/locon
+            elif patch_type == "lora": #lora/locon
                 mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
                 if v[2] is not None:
@@ -237,7 +251,7 @@ class ModelPatcher:
                     weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
-            elif len(v) == 8: #lokr
+            elif patch_type == "lokr":
                 w1 = v[0]
                 w2 = v[1]
                 w1_a = v[3]
@@ -276,7 +290,7 @@ class ModelPatcher:
                     weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
-            else: #loha
+            elif patch_type == "loha":
                 w1a = v[0]
                 w1b = v[1]
                 if v[2] is not None:
@@ -305,6 +319,18 @@ class ModelPatcher:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
                 except Exception as e:
                     print("ERROR", key, e)
+            elif patch_type == "glora":
+                if v[4] is not None:
+                    alpha *= v[4] / v[0].shape[0]
+
+                a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
+                a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
+                b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
+                b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
+
+                weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+            else:
+                print("patch type not recognized", patch_type, key)

         return weight
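Illustrative sketch (toy tensors, not from the patch): each weight patch is now a ("type", payload) pair dispatched on the tag rather than on tuple length, so payloads of the same length (e.g. a future 4-element format) can no longer collide with lora's.

    import torch

    weight = torch.zeros(8, 8)
    mat1, mat2 = torch.randn(8, 4), torch.randn(4, 8)   # toy lora up/down matrices
    patch = ("lora", (mat1, mat2, None, None))          # v[2]=alpha override, v[3]=mid weights

    patch_type, v = patch
    if patch_type == "lora":
        # same matmul shape logic as the "lora" branch above, minus device casting
        weight += torch.mm(v[0].flatten(start_dim=1), v[1].flatten(start_dim=1)).reshape(weight.shape)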
@@ -22,10 +22,17 @@ class V_PREDICTION(EPS):
 class ModelSamplingDiscrete(torch.nn.Module):
     def __init__(self, model_config=None):
         super().__init__()
-        beta_schedule = "linear"
         if model_config is not None:
-            beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
-        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+            sampling_settings = model_config.sampling_settings
+        else:
+            sampling_settings = {}
+
+        beta_schedule = sampling_settings.get("beta_schedule", "linear")
+        linear_start = sampling_settings.get("linear_start", 0.00085)
+        linear_end = sampling_settings.get("linear_end", 0.012)
+
+        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
        self.sigma_data = 1.0

     def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
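Illustrative sketch (hypothetical config class and values): a model config can now override the full beta schedule through `sampling_settings`, not just the schedule name.

    class MyModelConfig:  # placeholder; real configs live elsewhere in the codebase
        sampling_settings = {"beta_schedule": "linear",
                             "linear_start": 0.0001,
                             "linear_end": 0.02}

    model_sampling = ModelSamplingDiscrete(model_config=MyModelConfig())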
@@ -6,7 +6,7 @@ import hashlib
 import math
 import random

-from PIL import Image, ImageOps
+from PIL import Image, ImageOps, ImageSequence
 from PIL.PngImagePlugin import PngInfo
 import numpy as np
 import safetensors.torch
@@ -930,8 +930,8 @@ class GLIGENTextBoxApply:
         return (c, )

 class EmptyLatentImage:
-    def __init__(self, device="cpu"):
-        self.device = device
+    def __init__(self):
+        self.device = comfy.model_management.intermediate_device()

     @classmethod
     def INPUT_TYPES(s):
@@ -944,7 +944,7 @@ class EmptyLatentImage:
     CATEGORY = "latent"

     def generate(self, width, height, batch_size=1):
-        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
+        latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
         return ({"samples":latent}, )

@@ -1395,17 +1395,30 @@ class LoadImage:
     FUNCTION = "load_image"
     def load_image(self, image):
         image_path = folder_paths.get_annotated_filepath(image)
-        i = Image.open(image_path)
-        i = ImageOps.exif_transpose(i)
-        image = i.convert("RGB")
-        image = np.array(image).astype(np.float32) / 255.0
-        image = torch.from_numpy(image)[None,]
-        if 'A' in i.getbands():
-            mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
-            mask = 1. - torch.from_numpy(mask)
-        else:
-            mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
-        return (image, mask.unsqueeze(0))
+        img = Image.open(image_path)
+        output_images = []
+        output_masks = []
+        for i in ImageSequence.Iterator(img):
+            i = ImageOps.exif_transpose(i)
+            image = i.convert("RGB")
+            image = np.array(image).astype(np.float32) / 255.0
+            image = torch.from_numpy(image)[None,]
+            if 'A' in i.getbands():
+                mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
+                mask = 1. - torch.from_numpy(mask)
+            else:
+                mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
+            output_images.append(image)
+            output_masks.append(mask.unsqueeze(0))
+
+        if len(output_images) > 1:
+            output_image = torch.cat(output_images, dim=0)
+            output_mask = torch.cat(output_masks, dim=0)
+        else:
+            output_image = output_images[0]
+            output_mask = output_masks[0]
+
+        return (output_image, output_mask)

     @classmethod
     def IS_CHANGED(s, image):
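Illustrative standalone sketch (hypothetical file name) of the frame-batching pattern LoadImage now uses: every frame of an animated image becomes one row of a [frames, H, W, C] float batch.

    import numpy as np
    import torch
    from PIL import Image, ImageSequence

    img = Image.open("animation.gif")  # hypothetical multi-frame input
    frames = []
    for frame in ImageSequence.Iterator(img):
        arr = np.array(frame.convert("RGB")).astype(np.float32) / 255.0
        frames.append(torch.from_numpy(arr)[None,])  # add batch dim per frame
    batch = torch.cat(frames, dim=0) if len(frames) > 1 else frames[0]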
@@ -1463,13 +1476,10 @@ class LoadImageMask:
         return m.digest().hex()

     @classmethod
-    def VALIDATE_INPUTS(s, image, channel):
+    def VALIDATE_INPUTS(s, image):
         if not folder_paths.exists_annotated_filepath(image):
             return "Invalid image file: {}".format(image)

-        if channel not in s._color_channels:
-            return "Invalid color channel: {}".format(channel)
-
         return True

 class ImageScale:
comfy/ops.py (139 changed lines)
@@ -1,40 +1,115 @@
 import torch
 from contextlib import contextmanager
+import comfy.model_management

-class Linear(torch.nn.Linear):
-    def reset_parameters(self):
-        return None
-
-class Conv2d(torch.nn.Conv2d):
-    def reset_parameters(self):
-        return None
-
-class Conv3d(torch.nn.Conv3d):
-    def reset_parameters(self):
-        return None
-
-def conv_nd(dims, *args, **kwargs):
-    if dims == 2:
-        return Conv2d(*args, **kwargs)
-    elif dims == 3:
-        return Conv3d(*args, **kwargs)
-    else:
-        raise ValueError(f"unsupported dimensions: {dims}")
-
-@contextmanager
-def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
-    old_torch_nn_linear = torch.nn.Linear
-    force_device = device
-    force_dtype = dtype
-    def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
-        if force_device is not None:
-            device = force_device
-        if force_dtype is not None:
-            dtype = force_dtype
-        return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
-
-    torch.nn.Linear = linear_with_dtype
-    try:
-        yield
-    finally:
-        torch.nn.Linear = old_torch_nn_linear
+def cast_bias_weight(s, input):
+    bias = None
+    non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    return weight, bias
+
+class disable_weight_init:
+    class Linear(torch.nn.Linear):
+        comfy_cast_weights = False
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
+    class Conv2d(torch.nn.Conv2d):
+        comfy_cast_weights = False
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
+    class Conv3d(torch.nn.Conv3d):
+        comfy_cast_weights = False
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
+    class GroupNorm(torch.nn.GroupNorm):
+        comfy_cast_weights = False
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
+
+    class LayerNorm(torch.nn.LayerNorm):
+        comfy_cast_weights = False
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
+    @classmethod
+    def conv_nd(s, dims, *args, **kwargs):
+        if dims == 2:
+            return s.Conv2d(*args, **kwargs)
+        elif dims == 3:
+            return s.Conv3d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class manual_cast(disable_weight_init):
+    class Linear(disable_weight_init.Linear):
+        comfy_cast_weights = True
+
+    class Conv2d(disable_weight_init.Conv2d):
+        comfy_cast_weights = True
+
+    class Conv3d(disable_weight_init.Conv3d):
+        comfy_cast_weights = True
+
+    class GroupNorm(disable_weight_init.GroupNorm):
+        comfy_cast_weights = True
+
+    class LayerNorm(disable_weight_init.LayerNorm):
+        comfy_cast_weights = True
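Illustrative sketch (assuming the module above is importable as comfy.ops): disable_weight_init layers skip initialization because reset_parameters is a no-op, which is what you want when the weights come from a state dict anyway; the manual_cast subclasses additionally cast weight and bias to the input's device/dtype on every forward.

    import torch
    import comfy.ops

    # built uninitialized; weights are expected to be loaded from a state dict
    lin = comfy.ops.disable_weight_init.Linear(320, 320, dtype=torch.float16)

    # fp16 storage is cast to the fp32 input's dtype at call time
    lin_cast = comfy.ops.manual_cast.Linear(320, 320, dtype=torch.float16)
    out = lin_cast(torch.randn(1, 320, dtype=torch.float32))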
comfy/package_data_path_helper.py (new file, 9 lines)
@@ -0,0 +1,9 @@
+from importlib.resources import path
+import os
+
+
+def get_editable_resource_path(caller_file, *package_path):
+    filename = os.path.join(os.path.dirname(os.path.realpath(caller_file)), package_path[-1])
+    if not os.path.exists(filename):
+        filename = path(*package_path)
+    return filename
@@ -47,7 +47,8 @@ def convert_cond(cond):
         temp = c[1].copy()
         model_conds = temp.get("model_conds", {})
         if c[0] is not None:
-            model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0])
+            model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
+            temp["cross_attn"] = c[0]
         temp["model_conds"] = model_conds
         out.append(temp)
     return out
@@ -98,10 +99,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

     samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
-    samples = samples.cpu()
+    samples = samples.to(model_management.intermediate_device())

     cleanup_additional_models(models)
-    cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
+    cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples

 def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
@@ -111,8 +112,8 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
     sigmas = sigmas.to(model.load_device)

     samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
-    samples = samples.cpu()
+    samples = samples.to(model_management.intermediate_device())
     cleanup_additional_models(models)
-    cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
+    cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples

@@ -1,256 +1,264 @@
 from .k_diffusion import sampling as k_diffusion_sampling
 from .extra_samplers import uni_pc
 import torch
+import collections
 from . import model_management
 import math

+def get_area_and_mult(conds, x_in, timestep_in):
+    area = (x_in.shape[2], x_in.shape[3], 0, 0)
+    strength = 1.0
+
+    if 'timestep_start' in conds:
+        timestep_start = conds['timestep_start']
+        if timestep_in[0] > timestep_start:
+            return None
+    if 'timestep_end' in conds:
+        timestep_end = conds['timestep_end']
+        if timestep_in[0] < timestep_end:
+            return None
+    if 'area' in conds:
+        area = conds['area']
+    if 'strength' in conds:
+        strength = conds['strength']
+
+    input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
+    if 'mask' in conds:
+        # Scale the mask to the size of the input
+        # The mask should have been resized as we began the sampling process
+        mask_strength = 1.0
+        if "mask_strength" in conds:
+            mask_strength = conds["mask_strength"]
+        mask = conds['mask']
+        assert(mask.shape[1] == x_in.shape[2])
+        assert(mask.shape[2] == x_in.shape[3])
+        mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
+        mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
+    else:
+        mask = torch.ones_like(input_x)
+    mult = mask * strength
+
+    if 'mask' not in conds:
+        rr = 8
+        if area[2] != 0:
+            for t in range(rr):
+                mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
+        if (area[0] + area[2]) < x_in.shape[2]:
+            for t in range(rr):
+                mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
+        if area[3] != 0:
+            for t in range(rr):
+                mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
+        if (area[1] + area[3]) < x_in.shape[3]:
+            for t in range(rr):
+                mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
+
+    conditioning = {}
+    model_conds = conds["model_conds"]
+    for c in model_conds:
+        conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
+
+    control = conds.get('control', None)
+
+    patches = None
+    if 'gligen' in conds:
+        gligen = conds['gligen']
+        patches = {}
+        gligen_type = gligen[0]
+        gligen_model = gligen[1]
+        if gligen_type == "position":
+            gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
+        else:
+            gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
+
+        patches['middle_patch'] = [gligen_patch]
+
+    cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
+    return cond_obj(input_x, mult, conditioning, area, control, patches)
+
+def cond_equal_size(c1, c2):
+    if c1 is c2:
+        return True
+    if c1.keys() != c2.keys():
+        return False
+    for k in c1:
+        if not c1[k].can_concat(c2[k]):
+            return False
+    return True
+
+def can_concat_cond(c1, c2):
+    if c1.input_x.shape != c2.input_x.shape:
+        return False
+
+    def objects_concatable(obj1, obj2):
+        if (obj1 is None) != (obj2 is None):
+            return False
+        if obj1 is not None:
+            if obj1 is not obj2:
+                return False
+        return True
+
+    if not objects_concatable(c1.control, c2.control):
+        return False
+
+    if not objects_concatable(c1.patches, c2.patches):
+        return False
+
+    return cond_equal_size(c1.conditioning, c2.conditioning)
+
+def cond_cat(c_list):
+    c_crossattn = []
+    c_concat = []
+    c_adm = []
+    crossattn_max_len = 0
+
+    temp = {}
+    for x in c_list:
+        for k in x:
+            cur = temp.get(k, [])
+            cur.append(x[k])
+            temp[k] = cur
+
+    out = {}
+    for k in temp:
+        conds = temp[k]
+        out[k] = conds[0].concat(conds[1:])
+
+    return out
+
+def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
+    out_cond = torch.zeros_like(x_in)
+    out_count = torch.ones_like(x_in) * 1e-37
+
+    out_uncond = torch.zeros_like(x_in)
+    out_uncond_count = torch.ones_like(x_in) * 1e-37
+
+    COND = 0
+    UNCOND = 1
+
+    to_run = []
+    for x in cond:
+        p = get_area_and_mult(x, x_in, timestep)
+        if p is None:
+            continue
+
+        to_run += [(p, COND)]
+    if uncond is not None:
+        for x in uncond:
+            p = get_area_and_mult(x, x_in, timestep)
+            if p is None:
+                continue
+
+            to_run += [(p, UNCOND)]
+
+    while len(to_run) > 0:
+        first = to_run[0]
+        first_shape = first[0][0].shape
+        to_batch_temp = []
+        for x in range(len(to_run)):
+            if can_concat_cond(to_run[x][0], first[0]):
+                to_batch_temp += [x]
+
+        to_batch_temp.reverse()
+        to_batch = to_batch_temp[:1]
+
+        free_memory = model_management.get_free_memory(x_in.device)
+        for i in range(1, len(to_batch_temp) + 1):
+            batch_amount = to_batch_temp[:len(to_batch_temp)//i]
+            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
+            if model.memory_required(input_shape) < free_memory:
+                to_batch = batch_amount
+                break
+
+        input_x = []
+        mult = []
+        c = []
+        cond_or_uncond = []
+        area = []
+        control = None
+        patches = None
+        for x in to_batch:
+            o = to_run.pop(x)
+            p = o[0]
+            input_x.append(p.input_x)
+            mult.append(p.mult)
+            c.append(p.conditioning)
+            area.append(p.area)
+            cond_or_uncond.append(o[1])
+            control = p.control
+            patches = p.patches
+
+        batch_chunks = len(cond_or_uncond)
+        input_x = torch.cat(input_x)
+        c = cond_cat(c)
+        timestep_ = torch.cat([timestep] * batch_chunks)
+
+        if control is not None:
+            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
+
+        transformer_options = {}
+        if 'transformer_options' in model_options:
+            transformer_options = model_options['transformer_options'].copy()
+
+        if patches is not None:
+            if "patches" in transformer_options:
+                cur_patches = transformer_options["patches"].copy()
+                for p in patches:
+                    if p in cur_patches:
+                        cur_patches[p] = cur_patches[p] + patches[p]
+                    else:
+                        cur_patches[p] = patches[p]
+            else:
+                transformer_options["patches"] = patches
+
+        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
+        transformer_options["sigmas"] = timestep
+
+        c['transformer_options'] = transformer_options
+
+        if 'model_function_wrapper' in model_options:
+            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
+        else:
+            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
+        del input_x
+
+        for o in range(batch_chunks):
+            if cond_or_uncond[o] == COND:
+                out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
+                out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
+            else:
+                out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
+                out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
+        del mult
+
+    out_cond /= out_count
+    del out_count
+    out_uncond /= out_uncond_count
+    del out_uncond_count
+    return out_cond, out_uncond
+
 #The main sampling function shared by all the samplers
 #Returns denoised
 def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
-    def get_area_and_mult(conds, x_in, timestep_in):
-        area = (x_in.shape[2], x_in.shape[3], 0, 0)
-        strength = 1.0
-
-        if 'timestep_start' in conds:
-            timestep_start = conds['timestep_start']
-            if timestep_in[0] > timestep_start:
-                return None
-        if 'timestep_end' in conds:
-            timestep_end = conds['timestep_end']
-            if timestep_in[0] < timestep_end:
-                return None
-        if 'area' in conds:
-            area = conds['area']
-        if 'strength' in conds:
-            strength = conds['strength']
-
-        input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
-        if 'mask' in conds:
-            # Scale the mask to the size of the input
-            # The mask should have been resized as we began the sampling process
-            mask_strength = 1.0
-            if "mask_strength" in conds:
-                mask_strength = conds["mask_strength"]
-            mask = conds['mask']
-            assert(mask.shape[1] == x_in.shape[2])
-            assert(mask.shape[2] == x_in.shape[3])
-            mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
-            mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
-        else:
-            mask = torch.ones_like(input_x)
-        mult = mask * strength
-
-        if 'mask' not in conds:
-            rr = 8
-            if area[2] != 0:
-                for t in range(rr):
-                    mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
-            if (area[0] + area[2]) < x_in.shape[2]:
-                for t in range(rr):
-                    mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
-            if area[3] != 0:
-                for t in range(rr):
-                    mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
-            if (area[1] + area[3]) < x_in.shape[3]:
-                for t in range(rr):
-                    mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
-
-        conditionning = {}
-        model_conds = conds["model_conds"]
-        for c in model_conds:
-            conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
-
-        control = None
-        if 'control' in conds:
-            control = conds['control']
-
-        patches = None
-        if 'gligen' in conds:
-            gligen = conds['gligen']
-            patches = {}
-            gligen_type = gligen[0]
-            gligen_model = gligen[1]
-            if gligen_type == "position":
-                gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
-            else:
-                gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
-
-            patches['middle_patch'] = [gligen_patch]
-
-        return (input_x, mult, conditionning, area, control, patches)
-
-    def cond_equal_size(c1, c2):
-        if c1 is c2:
-            return True
-        if c1.keys() != c2.keys():
-            return False
-        for k in c1:
-            if not c1[k].can_concat(c2[k]):
-                return False
-        return True
-
-    def can_concat_cond(c1, c2):
-        if c1[0].shape != c2[0].shape:
-            return False
-
-        #control
-        if (c1[4] is None) != (c2[4] is None):
-            return False
-        if c1[4] is not None:
-            if c1[4] is not c2[4]:
-                return False
-
-        #patches
-        if (c1[5] is None) != (c2[5] is None):
-            return False
-        if (c1[5] is not None):
-            if c1[5] is not c2[5]:
-                return False
-
-        return cond_equal_size(c1[2], c2[2])
-
-    def cond_cat(c_list):
-        c_crossattn = []
-        c_concat = []
-        c_adm = []
-        crossattn_max_len = 0
-
-        temp = {}
-        for x in c_list:
-            for k in x:
-                cur = temp.get(k, [])
-                cur.append(x[k])
-                temp[k] = cur
-
-        out = {}
-        for k in temp:
-            conds = temp[k]
-            out[k] = conds[0].concat(conds[1:])
-
-        return out
-
-    def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
-        out_cond = torch.zeros_like(x_in)
-        out_count = torch.ones_like(x_in) * 1e-37
-
-        out_uncond = torch.zeros_like(x_in)
-        out_uncond_count = torch.ones_like(x_in) * 1e-37
-
-        COND = 0
-        UNCOND = 1
-
-        to_run = []
-        for x in cond:
-            p = get_area_and_mult(x, x_in, timestep)
-            if p is None:
-                continue
-
-            to_run += [(p, COND)]
-        if uncond is not None:
-            for x in uncond:
-                p = get_area_and_mult(x, x_in, timestep)
-                if p is None:
-                    continue
-
-                to_run += [(p, UNCOND)]
-
-        while len(to_run) > 0:
-            first = to_run[0]
-            first_shape = first[0][0].shape
-            to_batch_temp = []
-            for x in range(len(to_run)):
-                if can_concat_cond(to_run[x][0], first[0]):
-                    to_batch_temp += [x]
-
-            to_batch_temp.reverse()
-            to_batch = to_batch_temp[:1]
-
-            free_memory = model_management.get_free_memory(x_in.device)
-            for i in range(1, len(to_batch_temp) + 1):
-                batch_amount = to_batch_temp[:len(to_batch_temp)//i]
-                input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-                if model.memory_required(input_shape) < free_memory:
-                    to_batch = batch_amount
-                    break
-
-            input_x = []
-            mult = []
-            c = []
-            cond_or_uncond = []
-            area = []
-            control = None
-            patches = None
-            for x in to_batch:
-                o = to_run.pop(x)
-                p = o[0]
-                input_x += [p[0]]
-                mult += [p[1]]
-                c += [p[2]]
-                area += [p[3]]
-                cond_or_uncond += [o[1]]
-                control = p[4]
-                patches = p[5]
-
-            batch_chunks = len(cond_or_uncond)
-            input_x = torch.cat(input_x)
-            c = cond_cat(c)
-            timestep_ = torch.cat([timestep] * batch_chunks)
-
-            if control is not None:
-                c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
-
-            transformer_options = {}
-            if 'transformer_options' in model_options:
-                transformer_options = model_options['transformer_options'].copy()
-
-            if patches is not None:
-                if "patches" in transformer_options:
-                    cur_patches = transformer_options["patches"].copy()
-                    for p in patches:
-                        if p in cur_patches:
-                            cur_patches[p] = cur_patches[p] + patches[p]
-                        else:
-                            cur_patches[p] = patches[p]
-                else:
-                    transformer_options["patches"] = patches
-
-            transformer_options["cond_or_uncond"] = cond_or_uncond[:]
-            transformer_options["sigmas"] = timestep
-
-            c['transformer_options'] = transformer_options
-
-            if 'model_function_wrapper' in model_options:
-                output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
-            else:
-                output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
-            del input_x
-
-            for o in range(batch_chunks):
-                if cond_or_uncond[o] == COND:
-                    out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                    out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-                else:
-                    out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                    out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-            del mult
-
-        out_cond /= out_count
-        del out_count
-        out_uncond /= out_uncond_count
-        del out_uncond_count
-        return out_cond, out_uncond
-
-    if math.isclose(cond_scale, 1.0):
-        uncond = None
-
-    cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
-    if "sampler_cfg_function" in model_options:
-        args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
-        return x - model_options["sampler_cfg_function"](args)
-    else:
-        return uncond + (cond - uncond) * cond_scale
+    if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
+        uncond_ = None
+    else:
+        uncond_ = uncond
+
+    cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
+    if "sampler_cfg_function" in model_options:
+        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
+                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
+        cfg_result = x - model_options["sampler_cfg_function"](args)
+    else:
+        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
+
+    for fn in model_options.get("sampler_post_cfg_function", []):
+        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
+                "sigma": timestep, "model_options": model_options, "input": x}
+        cfg_result = fn(args)

+    return cfg_result

 class CFGNoisePredictor(torch.nn.Module):
     def __init__(self, model):
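For reference, a toy check (not from the patch) of why the cfg == 1.0 shortcut in sampling_function is sound: the guidance combine collapses to the conditioned prediction, so the uncond batch can be skipped entirely unless a hook asked for it via disable_cfg1_optimization.

    import torch

    def cfg_combine(cond_pred, uncond_pred, cond_scale):
        # classifier-free guidance: uncond + (cond - uncond) * scale
        return uncond_pred + (cond_pred - uncond_pred) * cond_scale

    c, u = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
    assert torch.allclose(cfg_combine(c, u, 1.0), c)  # at scale 1, uncond cancels out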
@@ -272,10 +280,7 @@ class KSamplerX0Inpaint(torch.nn.Module):
             x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
         out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
         if denoise_mask is not None:
-            out *= denoise_mask
-
-        if denoise_mask is not None:
-            out += self.latent_image * latent_mask
+            out = out * denoise_mask + self.latent_image * latent_mask
         return out

 def simple_scheduler(model, steps):
@@ -590,6 +595,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     calculate_start_end_timesteps(model, negative)
     calculate_start_end_timesteps(model, positive)

+    if latent_image is not None:
+        latent_image = model.process_latent_in(latent_image)
+
+    if hasattr(model, 'extra_conds'):
+        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
+        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
+
     #make sure each cond area has an opposite one with the same area
     for c in positive:
         create_cond_with_same_area_if_none(negative, c)
@@ -601,13 +613,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
     apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])

-    if latent_image is not None:
-        latent_image = model.process_latent_in(latent_image)
-
-    if hasattr(model, 'extra_conds'):
-        positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
-        negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
-
     extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}

     samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
@@ -630,7 +635,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
     elif scheduler_name == "sgm_uniform":
         sigmas = normal_scheduler(model, steps, sgm=True)
     else:
-        print("error invalid scheduler", self.scheduler)
+        print("error invalid scheduler", scheduler_name)
     return sigmas

 def sampler_object(name):
comfy/sd.py (48 changed lines)
@@ -148,12 +148,14 @@ class CLIP:
         return self.patcher.get_key_patches()

 class VAE:
-    def __init__(self, sd=None, device=None, config=None):
+    def __init__(self, sd=None, device=None, config=None, dtype=None):
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
             sd = diffusers_convert.convert_vae_state_dict(sd)

         self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
+        self.downscale_ratio = 8
+        self.latent_channels = 4

         if config is None:
             if "decoder.mid.block_1.mix_factor" in sd:
@@ -169,6 +171,11 @@ class VAE:
             else:
                 #default SD1.x/SD2.x VAE parameters
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
+
+                if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
+                    ddconfig['ch_mult'] = [1, 2, 4]
+                    self.downscale_ratio = 4
+
                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
         else:
             self.first_stage_model = AutoencoderKL(**(config['params']))
@@ -185,8 +192,11 @@ class VAE:
             device = model_management.vae_device()
         self.device = device
         offload_device = model_management.vae_offload_device()
-        self.vae_dtype = model_management.vae_dtype()
+        if dtype is None:
+            dtype = model_management.vae_dtype()
+        self.vae_dtype = dtype
         self.first_stage_model.to(self.vae_dtype)
+        self.output_device = model_management.intermediate_device()

         self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@@ -198,9 +208,9 @@ class VAE:

         decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
         output = torch.clamp((
-            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
-            utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
-            utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
+            (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
+            utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
+            utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
             / 3.0) / 2.0, min=0.0, max=1.0)
         return output
@@ -211,9 +221,9 @@ class VAE:
         pbar = utils.ProgressBar(steps)

         encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
-        samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
-        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
+        samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
+        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
+        samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
         samples /= 3.0
         return samples
@@ -225,15 +235,15 @@ class VAE:
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)

-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)

-        pixel_samples = pixel_samples.cpu().movedim(1,-1)
+        pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples

     def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@ -249,10 +259,10 @@ class VAE:
|
|||||||
free_memory = model_management.get_free_memory(self.device)
|
free_memory = model_management.get_free_memory(self.device)
|
||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / memory_used)
|
||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
|
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
|
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
|
||||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
|
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||||
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
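A minimal sketch of the memory-aware batching used in both VAE paths above: the batch size is derived from free device memory and floored at 1, and the preallocated output tensor now lives on output_device instead of hard-coded CPU. Names and the memory figures are illustrative, not the real estimator.

import torch

def pick_batch_size(free_memory: float, memory_per_sample: float) -> int:
    # always process at least one sample, even when memory looks tight
    return max(1, int(free_memory / memory_per_sample))

free, per_sample = 8e9, 2.5e9
batch = pick_batch_size(free, per_sample)          # 3
out = torch.empty(10, 3, 512, 512, device="cpu")   # stand-in for output_device
for x in range(0, out.shape[0], batch):
    out[x:x + batch] = torch.zeros(min(batch, out.shape[0] - x), 3, 512, 512)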
@ -429,11 +439,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
|
|
||||||
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
|
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||||
|
load_device = model_management.get_torch_device()
|
||||||
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
|
|
||||||
class WeightsLoader(torch.nn.Module):
|
class WeightsLoader(torch.nn.Module):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
||||||
|
model_config.set_manual_cast(manual_cast_dtype)
|
||||||
|
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||||
|
|
||||||
@ -466,7 +480,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
print("left over keys:", left_over)
|
print("left over keys:", left_over)
|
||||||
|
|
||||||
if output_model:
|
if output_model:
|
||||||
_model_patcher = model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||||
if inital_load_device != torch.device("cpu"):
|
if inital_load_device != torch.device("cpu"):
|
||||||
print("loaded straight to GPU")
|
print("loaded straight to GPU")
|
||||||
model_management.load_model_gpu(model_patcher)
|
model_management.load_model_gpu(model_patcher)
|
||||||
@ -477,6 +491,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
def load_unet_state_dict(sd): #load unet in diffusers format
|
def load_unet_state_dict(sd): #load unet in diffusers format
|
||||||
parameters = utils.calculate_parameters(sd)
|
parameters = utils.calculate_parameters(sd)
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||||
|
load_device = model_management.get_torch_device()
|
||||||
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
|
|
||||||
if "input_blocks.0.0.weight" in sd: #ldm
|
if "input_blocks.0.0.weight" in sd: #ldm
|
||||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
@ -497,13 +514,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
|
|||||||
else:
|
else:
|
||||||
print(diffusers_keys[k], k)
|
print(diffusers_keys[k], k)
|
||||||
offload_device = model_management.unet_offload_device()
|
offload_device = model_management.unet_offload_device()
|
||||||
|
model_config.set_manual_cast(manual_cast_dtype)
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model = model.to(offload_device)
|
model = model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "")
|
model.load_model_weights(new_sd, "")
|
||||||
left_over = sd.keys()
|
left_over = sd.keys()
|
||||||
if len(left_over) > 0:
|
if len(left_over) > 0:
|
||||||
print("left over keys in unet:", left_over)
|
print("left over keys in unet:", left_over)
|
||||||
return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||||
|
|
||||||
def load_unet(unet_path):
|
def load_unet(unet_path):
|
||||||
sd = utils.load_torch_file(unet_path)
|
sd = utils.load_torch_file(unet_path)
|
||||||
|
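The manual_cast_dtype threaded through both loaders above covers devices whose storage dtype differs from a dtype they can actually compute in. A sketch of the assumed mechanism, with illustrative names (the real selection logic lives in model_management.unet_manual_cast and the cast ops in comfy.ops):

import torch
import torch.nn as nn

class ManualCastLinear(nn.Linear):
    compute_dtype = torch.float32  # illustrative; chosen per device in practice

    def forward(self, x):
        # parameters stay in their storage dtype; cast only at compute time
        w = self.weight.to(self.compute_dtype)
        b = self.bias.to(self.compute_dtype) if self.bias is not None else None
        return torch.nn.functional.linear(x.to(self.compute_dtype), w, b)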
|||||||
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
|
from transformers import CLIPTokenizer
|
||||||
from . import ops
|
from . import ops
|
||||||
import torch
|
import torch
|
||||||
import traceback
|
import traceback
|
||||||
@ -8,6 +8,8 @@ import zipfile
|
|||||||
from . import model_management
|
from . import model_management
|
||||||
from pkg_resources import resource_filename
|
from pkg_resources import resource_filename
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from . import clip_model
|
||||||
|
import json
|
||||||
|
|
||||||
def gen_empty_tokens(special_tokens, length):
|
def gen_empty_tokens(special_tokens, length):
|
||||||
start_token = special_tokens.get("start", None)
|
start_token = special_tokens.get("start", None)
|
||||||
@ -38,7 +40,7 @@ class ClipTokenWeightEncoder:
|
|||||||
|
|
||||||
out, pooled = self.encode(to_encode)
|
out, pooled = self.encode(to_encode)
|
||||||
if pooled is not None:
|
if pooled is not None:
|
||||||
first_pooled = pooled[0:1].cpu()
|
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
||||||
else:
|
else:
|
||||||
first_pooled = pooled
|
first_pooled = pooled
|
||||||
|
|
||||||
@ -55,8 +57,8 @@ class ClipTokenWeightEncoder:
|
|||||||
output.append(z)
|
output.append(z)
|
||||||
|
|
||||||
if (len(output) == 0):
|
if (len(output) == 0):
|
||||||
return out[-1:].cpu(), first_pooled
|
return out[-1:].to(model_management.intermediate_device()), first_pooled
|
||||||
return torch.cat(output, dim=-2).cpu(), first_pooled
|
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
|
||||||
|
|
||||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||||
@ -66,33 +68,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
"hidden"
|
"hidden"
|
||||||
]
|
]
|
||||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
|
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=clip_model.CLIPTextModel,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
|
||||||
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert layer in self.LAYERS
|
assert layer in self.LAYERS
|
||||||
self.num_layers = 12
|
|
||||||
if textmodel_path is not None:
|
|
||||||
self.transformer = model_class.from_pretrained(textmodel_path)
|
|
||||||
else:
|
|
||||||
if textmodel_json_config is None:
|
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
|
||||||
if not os.path.exists(textmodel_json_config):
|
|
||||||
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
|
|
||||||
config = config_class.from_json_file(textmodel_json_config)
|
|
||||||
self.num_layers = config.num_hidden_layers
|
|
||||||
with ops.use_comfy_ops(device, dtype):
|
|
||||||
with modeling_utils.no_init_weights():
|
|
||||||
self.transformer = model_class(config)
|
|
||||||
|
|
||||||
self.inner_name = inner_name
|
if textmodel_json_config is None:
|
||||||
if dtype is not None:
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||||
self.transformer.to(dtype)
|
if not os.path.exists(textmodel_json_config):
|
||||||
inner_model = getattr(self.transformer, self.inner_name)
|
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
|
||||||
if hasattr(inner_model, "embeddings"):
|
|
||||||
inner_model.embeddings.to(torch.float32)
|
with open(textmodel_json_config) as f:
|
||||||
else:
|
config = json.load(f)
|
||||||
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
|
|
||||||
|
self.transformer = model_class(config, dtype, device, ops.manual_cast)
|
||||||
|
self.num_layers = self.transformer.num_layers
|
||||||
|
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
if freeze:
|
if freeze:
|
||||||
@ -107,7 +97,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||||
if layer == "hidden":
|
if layer == "hidden":
|
||||||
assert layer_idx is not None
|
assert layer_idx is not None
|
||||||
assert abs(layer_idx) <= self.num_layers
|
assert abs(layer_idx) < self.num_layers
|
||||||
self.clip_layer(layer_idx)
|
self.clip_layer(layer_idx)
|
||||||
self.layer_default = (self.layer, self.layer_idx)
|
self.layer_default = (self.layer, self.layer_idx)
|
||||||
|
|
||||||
@ -118,7 +108,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
def clip_layer(self, layer_idx):
|
def clip_layer(self, layer_idx):
|
||||||
if abs(layer_idx) >= self.num_layers:
|
if abs(layer_idx) > self.num_layers:
|
||||||
self.layer = "last"
|
self.layer = "last"
|
||||||
else:
|
else:
|
||||||
self.layer = "hidden"
|
self.layer = "hidden"
|
||||||
@ -173,41 +163,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||||
tokens = torch.LongTensor(tokens).to(device)
|
tokens = torch.LongTensor(tokens).to(device)
|
||||||
|
|
||||||
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
|
attention_mask = None
|
||||||
precision_scope = torch.autocast
|
if self.enable_attention_masks:
|
||||||
|
attention_mask = torch.zeros_like(tokens)
|
||||||
|
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
|
||||||
|
for x in range(attention_mask.shape[0]):
|
||||||
|
for y in range(attention_mask.shape[1]):
|
||||||
|
attention_mask[x, y] = 1
|
||||||
|
if tokens[x, y] == max_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
||||||
|
self.transformer.set_input_embeddings(backup_embeds)
|
||||||
|
|
||||||
|
if self.layer == "last":
|
||||||
|
z = outputs[0]
|
||||||
else:
|
else:
|
||||||
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
|
z = outputs[1]
|
||||||
|
|
||||||
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
|
if outputs[2] is not None:
|
||||||
attention_mask = None
|
pooled_output = outputs[2].float()
|
||||||
if self.enable_attention_masks:
|
else:
|
||||||
attention_mask = torch.zeros_like(tokens)
|
pooled_output = None
|
||||||
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
|
|
||||||
for x in range(attention_mask.shape[0]):
|
|
||||||
for y in range(attention_mask.shape[1]):
|
|
||||||
attention_mask[x, y] = 1
|
|
||||||
if tokens[x, y] == max_token:
|
|
||||||
break
|
|
||||||
|
|
||||||
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
|
if self.text_projection is not None and pooled_output is not None:
|
||||||
self.transformer.set_input_embeddings(backup_embeds)
|
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
|
||||||
|
|
||||||
if self.layer == "last":
|
|
||||||
z = outputs.last_hidden_state
|
|
||||||
elif self.layer == "pooled":
|
|
||||||
z = outputs.pooler_output[:, None, :]
|
|
||||||
else:
|
|
||||||
z = outputs.hidden_states[self.layer_idx]
|
|
||||||
if self.layer_norm_hidden_state:
|
|
||||||
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
|
|
||||||
|
|
||||||
if hasattr(outputs, "pooler_output"):
|
|
||||||
pooled_output = outputs.pooler_output.float()
|
|
||||||
else:
|
|
||||||
pooled_output = None
|
|
||||||
|
|
||||||
if self.text_projection is not None and pooled_output is not None:
|
|
||||||
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
|
|
||||||
return z.float(), pooled_output
|
return z.float(), pooled_output
|
||||||
|
|
||||||
def encode(self, tokens):
|
def encode(self, tokens):
|
||||||
|
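The new attention-mask loop above marks every token position up to and including the first end token. A vectorized equivalent, shown as a sketch (end_token stands in for the embedding-table index used above):

import torch

def build_attention_mask(tokens: torch.Tensor, end_token: int) -> torch.Tensor:
    # 1 up to and including the first end_token in each row, 0 afterwards
    is_end = (tokens == end_token).long()
    seen_before = torch.cumsum(is_end, dim=1) - is_end
    return (seen_before == 0).long()

tokens = torch.tensor([[5, 9, 49407, 49407, 49407]])
print(build_attention_mask(tokens, 49407))  # tensor([[1, 1, 1, 0, 0]])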
|||||||
@ -4,15 +4,15 @@ from . import sd1_clip
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
||||||
if layer == "penultimate":
|
if layer == "penultimate":
|
||||||
layer="hidden"
|
layer="hidden"
|
||||||
layer_idx=23
|
layer_idx=-2
|
||||||
|
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
||||||
if not os.path.exists(textmodel_json_config):
|
if not os.path.exists(textmodel_json_config):
|
||||||
textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
|
textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||||
|
|
||||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||||
|
|||||||
@ -3,13 +3,13 @@ import torch
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class SDXLClipG(sd1_clip.SDClipModel):
|
class SDXLClipG(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
|
||||||
if layer == "penultimate":
|
if layer == "penultimate":
|
||||||
layer="hidden"
|
layer="hidden"
|
||||||
layer_idx=-2
|
layer_idx=-2
|
||||||
|
|
||||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
@ -37,7 +37,7 @@ class SDXLTokenizer:
|
|||||||
class SDXLClipModel(torch.nn.Module):
|
class SDXLClipModel(torch.nn.Module):
|
||||||
def __init__(self, device="cpu", dtype=None):
|
def __init__(self, device="cpu", dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||||
|
|
||||||
def clip_layer(self, layer_idx):
|
def clip_layer(self, layer_idx):
|
||||||
|
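The switch from hard-coded indices (23 for ViT-bigG, 11 for ViT-L) to layer_idx=-2 makes "penultimate" independent of model depth, because negative indices count from the end of the hidden-state list:

# illustrative only: hidden states indexed like a Python list
for num_layers in (32, 12):                      # ViT-bigG and ViT-L depths
    hidden_states = list(range(num_layers))
    assert hidden_states[-2] == num_layers - 2   # 30 and 10: the penultimate layer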
|||||||
@ -217,6 +217,16 @@ class SSD1B(SDXL):
|
|||||||
"use_temporal_attention": False,
|
"use_temporal_attention": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class Segmind_Vega(SDXL):
|
||||||
|
unet_config = {
|
||||||
|
"model_channels": 320,
|
||||||
|
"use_linear_in_transformer": True,
|
||||||
|
"transformer_depth": [0, 0, 1, 1, 2, 2],
|
||||||
|
"context_dim": 2048,
|
||||||
|
"adm_in_channels": 2816,
|
||||||
|
"use_temporal_attention": False,
|
||||||
|
}
|
||||||
|
|
||||||
class SVD_img2vid(supported_models_base.BASE):
|
class SVD_img2vid(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"model_channels": 320,
|
"model_channels": 320,
|
||||||
@ -242,5 +252,59 @@ class SVD_img2vid(supported_models_base.BASE):
|
|||||||
def clip_target(self):
|
def clip_target(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
|
class Stable_Zero123(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"context_dim": 768,
|
||||||
|
"model_channels": 320,
|
||||||
|
"use_linear_in_transformer": False,
|
||||||
|
"adm_in_channels": None,
|
||||||
|
"use_temporal_attention": False,
|
||||||
|
"in_channels": 8,
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {
|
||||||
|
"num_heads": 8,
|
||||||
|
"num_head_channels": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
clip_vision_prefix = "cond_stage_model.model.visual."
|
||||||
|
|
||||||
|
latent_format = latent_formats.SD15
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class SD_X4Upscaler(SD20):
|
||||||
|
unet_config = {
|
||||||
|
"context_dim": 1024,
|
||||||
|
"model_channels": 256,
|
||||||
|
'in_channels': 7,
|
||||||
|
"use_linear_in_transformer": True,
|
||||||
|
"adm_in_channels": None,
|
||||||
|
"use_temporal_attention": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {
|
||||||
|
"disable_self_attentions": [True, True, True, False],
|
||||||
|
"num_classes": 1000,
|
||||||
|
"num_heads": 8,
|
||||||
|
"num_head_channels": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
latent_format = latent_formats.SD_X4
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"linear_start": 0.0001,
|
||||||
|
"linear_end": 0.02,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.SD_X4Upscaler(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
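How the new entries join detection: each class's unet_config acts as a template, and a checkpoint matches the first class whose template keys all agree with the config parsed from its state dict. A self-contained sketch (matches mirrors the classmethod shown earlier; the configs are abbreviated):

segmind_vega = {"model_channels": 320, "transformer_depth": [0, 0, 1, 1, 2, 2],
                "context_dim": 2048, "adm_in_channels": 2816}
parsed = dict(segmind_vega, in_channels=4)  # extra parsed keys are ignored

def matches(template, unet_config):
    return all(unet_config.get(k) == v for k, v in template.items())

print(matches(segmind_vega, parsed))  # True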
|||||||
@ -22,6 +22,8 @@ class BASE:
|
|||||||
sampling_settings = {}
|
sampling_settings = {}
|
||||||
latent_format = latent_formats.LatentFormat
|
latent_format = latent_formats.LatentFormat
|
||||||
|
|
||||||
|
manual_cast_dtype = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def matches(s, unet_config):
|
def matches(s, unet_config):
|
||||||
for k in s.unet_config:
|
for k in s.unet_config:
|
||||||
@ -71,3 +73,5 @@ class BASE:
|
|||||||
replace_prefix = {"": "first_stage_model."}
|
replace_prefix = {"": "first_stage_model."}
|
||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
|
def set_manual_cast(self, manual_cast_dtype):
|
||||||
|
self.manual_cast_dtype = manual_cast_dtype
|
||||||
|
|||||||
@ -7,9 +7,10 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .. import utils
|
from .. import utils
|
||||||
|
from .. import ops
|
||||||
|
|
||||||
def conv(n_in, n_out, **kwargs):
|
def conv(n_in, n_out, **kwargs):
|
||||||
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
return ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||||
|
|
||||||
class Clamp(nn.Module):
|
class Clamp(nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -19,7 +20,7 @@ class Block(nn.Module):
|
|||||||
def __init__(self, n_in, n_out):
|
def __init__(self, n_in, n_out):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
||||||
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
self.skip = ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||||
self.fuse = nn.ReLU()
|
self.fuse = nn.ReLU()
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.fuse(self.conv(x) + self.skip(x))
|
return self.fuse(self.conv(x) + self.skip(x))
|
||||||
|
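Swapping nn.Conv2d for ops.disable_weight_init.Conv2d skips the default random initialization, which is wasted work when a checkpoint overwrites the weights immediately after construction. A sketch of the assumed mechanism:

import torch.nn as nn

class Conv2d(nn.Conv2d):
    def reset_parameters(self):
        return None  # no default init; weights arrive via load_state_dict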
|||||||
@ -378,7 +378,7 @@ def lanczos(samples, width, height):
|
|||||||
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
||||||
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
||||||
result = torch.stack(images)
|
result = torch.stack(images)
|
||||||
return result
|
return result.to(samples.device, samples.dtype)
|
||||||
|
|
||||||
def common_upscale(samples, width, height, upscale_method, crop):
|
def common_upscale(samples, width, height, upscale_method, crop):
|
||||||
if crop == "center":
|
if crop == "center":
|
||||||
@ -407,17 +407,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
|||||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
|
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
|
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
|
||||||
for b in range(samples.shape[0]):
|
for b in range(samples.shape[0]):
|
||||||
s = samples[b:b+1]
|
s = samples[b:b+1]
|
||||||
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
|
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
||||||
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
|
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
|
||||||
for y in range(0, s.shape[2], tile_y - overlap):
|
for y in range(0, s.shape[2], tile_y - overlap):
|
||||||
for x in range(0, s.shape[3], tile_x - overlap):
|
for x in range(0, s.shape[3], tile_x - overlap):
|
||||||
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
|
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
|
||||||
|
|
||||||
ps = function(s_in).cpu()
|
ps = function(s_in).to(output_device)
|
||||||
mask = torch.ones_like(ps)
|
mask = torch.ones_like(ps)
|
||||||
feather = round(overlap * upscale_amount)
|
feather = round(overlap * upscale_amount)
|
||||||
for t in range(feather):
|
for t in range(feather):
|
||||||
|
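The out / out_div pair above implements weighted overlap blending: every tile adds its feathered result into out and its mask into out_div, and dividing at the end renormalizes regions covered by several tiles. A toy 1-D illustration (masks left flat here; the real code feathers tile edges):

import torch

out = torch.zeros(1, 1, 1, 10)
out_div = torch.zeros(1, 1, 1, 10)
for x0 in (0, 4):                        # two overlapping 6-wide tiles
    tile = torch.ones(1, 1, 1, 6)
    mask = torch.ones_like(tile)
    out[..., x0:x0 + 6] += tile * mask
    out_div[..., x0:x0 + 6] += mask
print((out / out_div).flatten())         # all ones: the overlap is renormalized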
|||||||
@ -291,7 +291,7 @@ class Canny:
|
|||||||
|
|
||||||
def detect_edge(self, image, low_threshold, high_threshold):
|
def detect_edge(self, image, low_threshold, high_threshold):
|
||||||
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
|
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
|
||||||
img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
|
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
|
||||||
return (img_out,)
|
return (img_out,)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
import comfy.samplers
|
from comfy import samplers
|
||||||
import comfy.sample
|
from comfy import sample
|
||||||
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
from comfy.k_diffusion import sampling as k_diffusion_sampling
|
||||||
from comfy.cmd import latent_preview
|
from comfy.cmd import latent_preview
|
||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
from comfy import utils
|
||||||
|
|
||||||
|
|
||||||
class BasicScheduler:
|
class BasicScheduler:
|
||||||
@ -11,8 +11,9 @@ class BasicScheduler:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required":
|
return {"required":
|
||||||
{"model": ("MODEL",),
|
{"model": ("MODEL",),
|
||||||
"scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
|
"scheduler": (samplers.SCHEDULER_NAMES, ),
|
||||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||||
|
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
RETURN_TYPES = ("SIGMAS",)
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
@ -20,8 +21,15 @@ class BasicScheduler:
|
|||||||
|
|
||||||
FUNCTION = "get_sigmas"
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
def get_sigmas(self, model, scheduler, steps):
|
def get_sigmas(self, model, scheduler, steps, denoise):
|
||||||
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu()
|
total_steps = steps
|
||||||
|
if denoise < 1.0:
|
||||||
|
total_steps = int(steps/denoise)
|
||||||
|
|
||||||
|
inner_model = model.patch_model(patch_weights=False)
|
||||||
|
sigmas = samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
|
||||||
|
model.unpatch_model()
|
||||||
|
sigmas = sigmas[-(steps + 1):]
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
|
|
||||||
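The denoise handling above computes a longer full schedule and keeps only its tail, so partial denoising reuses the low-noise end of the curve. Worked numbers, assuming the same arithmetic:

steps, denoise = 20, 0.5
total_steps = int(steps / denoise)        # 40: a full 40-step schedule is built
# calculate_sigmas_scheduler returns total_steps + 1 = 41 sigma values;
# sigmas[-(steps + 1):] keeps the last 21, i.e. the low-noise tail that a
# denoise strength of 0.5 should traverse
print(total_steps, steps + 1)             # 40 21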
@ -87,6 +95,7 @@ class SDTurboScheduler:
|
|||||||
return {"required":
|
return {"required":
|
||||||
{"model": ("MODEL",),
|
{"model": ("MODEL",),
|
||||||
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
|
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
|
||||||
|
"denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
RETURN_TYPES = ("SIGMAS",)
|
RETURN_TYPES = ("SIGMAS",)
|
||||||
@ -94,9 +103,12 @@ class SDTurboScheduler:
|
|||||||
|
|
||||||
FUNCTION = "get_sigmas"
|
FUNCTION = "get_sigmas"
|
||||||
|
|
||||||
def get_sigmas(self, model, steps):
|
def get_sigmas(self, model, steps, denoise):
|
||||||
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
|
start_step = 10 - int(10 * denoise)
|
||||||
sigmas = model.model.model_sampling.sigma(timesteps)
|
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
|
||||||
|
inner_model = model.patch_model(patch_weights=False)
|
||||||
|
sigmas = inner_model.model_sampling.sigma(timesteps)
|
||||||
|
model.unpatch_model()
|
||||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||||
return (sigmas, )
|
return (sigmas, )
|
||||||
|
|
||||||
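SDTurboScheduler's denoise now shifts the starting timestep instead of rescaling: with ten fixed 100-step slots, denoise=0.5 starts halfway down the noise ladder. A runnable check of the indexing:

import torch

steps, denoise = 4, 0.5
start_step = 10 - int(10 * denoise)               # 5
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))
print(timesteps[start_step:start_step + steps])   # tensor([499, 399, 299, 199])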
@ -159,7 +171,7 @@ class KSamplerSelect:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required":
|
return {"required":
|
||||||
{"sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
|
{"sampler_name": (samplers.SAMPLER_NAMES, ),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
RETURN_TYPES = ("SAMPLER",)
|
RETURN_TYPES = ("SAMPLER",)
|
||||||
@ -168,7 +180,7 @@ class KSamplerSelect:
|
|||||||
FUNCTION = "get_sampler"
|
FUNCTION = "get_sampler"
|
||||||
|
|
||||||
def get_sampler(self, sampler_name):
|
def get_sampler(self, sampler_name):
|
||||||
sampler = comfy.samplers.sampler_object(sampler_name)
|
sampler = samplers.sampler_object(sampler_name)
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
class SamplerDPMPP_2M_SDE:
|
class SamplerDPMPP_2M_SDE:
|
||||||
@ -191,7 +203,7 @@ class SamplerDPMPP_2M_SDE:
|
|||||||
sampler_name = "dpmpp_2m_sde"
|
sampler_name = "dpmpp_2m_sde"
|
||||||
else:
|
else:
|
||||||
sampler_name = "dpmpp_2m_sde_gpu"
|
sampler_name = "dpmpp_2m_sde_gpu"
|
||||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
|
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
|
|
||||||
@ -215,7 +227,7 @@ class SamplerDPMPP_SDE:
|
|||||||
sampler_name = "dpmpp_sde"
|
sampler_name = "dpmpp_sde"
|
||||||
else:
|
else:
|
||||||
sampler_name = "dpmpp_sde_gpu"
|
sampler_name = "dpmpp_sde_gpu"
|
||||||
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
||||||
return (sampler, )
|
return (sampler, )
|
||||||
|
|
||||||
class SamplerCustom:
|
class SamplerCustom:
|
||||||
@ -248,7 +260,7 @@ class SamplerCustom:
|
|||||||
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
|
||||||
else:
|
else:
|
||||||
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
batch_inds = latent["batch_index"] if "batch_index" in latent else None
|
||||||
noise = comfy.sample.prepare_noise(latent_image, noise_seed, batch_inds)
|
noise = sample.prepare_noise(latent_image, noise_seed, batch_inds)
|
||||||
|
|
||||||
noise_mask = None
|
noise_mask = None
|
||||||
if "noise_mask" in latent:
|
if "noise_mask" in latent:
|
||||||
@ -257,8 +269,8 @@ class SamplerCustom:
|
|||||||
x0_output = {}
|
x0_output = {}
|
||||||
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
|
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
|
||||||
|
|
||||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
disable_pbar = not utils.PROGRESS_BAR_ENABLED
|
||||||
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
|
samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
|
||||||
|
|
||||||
out = latent.copy()
|
out = latent.copy()
|
||||||
out["samples"] = samples
|
out["samples"] = samples
|
||||||
|
|||||||
@ -2,9 +2,10 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import random
|
# Use torch rng for consistency across generations
|
||||||
|
from torch import randint
|
||||||
|
|
||||||
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
|
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||||
min_value = min(min_value, value)
|
min_value = min(min_value, value)
|
||||||
|
|
||||||
# All big divisors of value (inclusive)
|
# All big divisors of value (inclusive)
|
||||||
@ -12,8 +13,10 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter
|
|||||||
|
|
||||||
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
|
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
|
||||||
|
|
||||||
random.seed(counter)
|
if len(ns) - 1 > 0:
|
||||||
idx = random.randint(0, len(ns) - 1)
|
idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
|
||||||
|
else:
|
||||||
|
idx = 0
|
||||||
|
|
||||||
return ns[idx]
|
return ns[idx]
|
||||||
|
|
||||||
@ -42,7 +45,6 @@ class HyperTile:
|
|||||||
|
|
||||||
latent_tile_size = max(32, tile_size) // 8
|
latent_tile_size = max(32, tile_size) // 8
|
||||||
self.temp = None
|
self.temp = None
|
||||||
self.counter = 1
|
|
||||||
|
|
||||||
def hypertile_in(q, k, v, extra_options):
|
def hypertile_in(q, k, v, extra_options):
|
||||||
if q.shape[-1] in apply_to:
|
if q.shape[-1] in apply_to:
|
||||||
@ -53,10 +55,8 @@ class HyperTile:
|
|||||||
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
||||||
|
|
||||||
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
|
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
|
||||||
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
|
nh = random_divisor(h, latent_tile_size * factor, swap_size)
|
||||||
self.counter += 1
|
nw = random_divisor(w, latent_tile_size * factor, swap_size)
|
||||||
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
|
|
||||||
self.counter += 1
|
|
||||||
|
|
||||||
if nh * nw > 1:
|
if nh * nw > 1:
|
||||||
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||||
|
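A note on the torch API used in random_divisor above: torch.randint samples the half-open interval [low, high), so high=len(ns) - 1 draws from indices 0..len(ns)-2. For comparison, a sketch drawing uniformly over every index of ns would pass high=len(ns):

import torch

ns = [8, 4, 2]
idx = torch.randint(low=0, high=len(ns), size=(1,)).item()  # any of 0, 1, 2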
|||||||
@ -73,7 +73,7 @@ class SaveAnimatedWEBP:
|
|||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "image/animation"
|
||||||
|
|
||||||
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
|
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
|
||||||
method = self.methods.get(method)
|
method = self.methods.get(method)
|
||||||
@ -135,7 +135,7 @@ class SaveAnimatedPNG:
|
|||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "image/animation"
|
||||||
|
|
||||||
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
filename_prefix += self.prefix_append
|
filename_prefix += self.prefix_append
|
||||||
|
|||||||
@ -3,9 +3,7 @@ import torch
|
|||||||
|
|
||||||
def reshape_latent_to(target_shape, latent):
|
def reshape_latent_to(target_shape, latent):
|
||||||
if latent.shape[1:] != target_shape[1:]:
|
if latent.shape[1:] != target_shape[1:]:
|
||||||
latent.movedim(1, -1)
|
|
||||||
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
|
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
|
||||||
latent.movedim(-1, 1)
|
|
||||||
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
|
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
|
||||||
|
|
||||||
|
|
||||||
@ -102,9 +100,32 @@ class LatentInterpolate:
|
|||||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||||
return (samples_out,)
|
return (samples_out,)
|
||||||
|
|
||||||
|
class LatentBatch:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("LATENT",)
|
||||||
|
FUNCTION = "batch"
|
||||||
|
|
||||||
|
CATEGORY = "latent/batch"
|
||||||
|
|
||||||
|
def batch(self, samples1, samples2):
|
||||||
|
samples_out = samples1.copy()
|
||||||
|
s1 = samples1["samples"]
|
||||||
|
s2 = samples2["samples"]
|
||||||
|
|
||||||
|
if s1.shape[1:] != s2.shape[1:]:
|
||||||
|
s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
|
||||||
|
s = torch.cat((s1, s2), dim=0)
|
||||||
|
samples_out["samples"] = s
|
||||||
|
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
||||||
|
return (samples_out,)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"LatentAdd": LatentAdd,
|
"LatentAdd": LatentAdd,
|
||||||
"LatentSubtract": LatentSubtract,
|
"LatentSubtract": LatentSubtract,
|
||||||
"LatentMultiply": LatentMultiply,
|
"LatentMultiply": LatentMultiply,
|
||||||
"LatentInterpolate": LatentInterpolate,
|
"LatentInterpolate": LatentInterpolate,
|
||||||
|
"LatentBatch": LatentBatch,
|
||||||
}
|
}
|
||||||
|
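LatentBatch concatenates along the batch dimension after resizing the second latent to the first's spatial shape. A usage sketch (comfy.utils.common_upscale is the helper shown elsewhere in this diff; shapes are illustrative):

import torch
import comfy.utils

s1 = torch.randn(2, 4, 64, 64)
s2 = torch.randn(1, 4, 32, 32)
if s1.shape[1:] != s2.shape[1:]:
    s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
batched = torch.cat((s1, s2), dim=0)   # shape [3, 4, 64, 64]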
|||||||
@ -7,6 +7,7 @@ from comfy.nodes.common import MAX_RESOLUTION
|
|||||||
|
|
||||||
|
|
||||||
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
||||||
|
source = source.to(destination.device)
|
||||||
if resize_source:
|
if resize_source:
|
||||||
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
|
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
|
||||||
|
|
||||||
@ -21,7 +22,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
|||||||
if mask is None:
|
if mask is None:
|
||||||
mask = torch.ones_like(source)
|
mask = torch.ones_like(source)
|
||||||
else:
|
else:
|
||||||
mask = mask.clone()
|
mask = mask.to(destination.device, copy=True)
|
||||||
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
|
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
|
||||||
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
|
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
|
||||||
|
|
||||||
|
|||||||
@ -16,41 +16,19 @@ class LCM(comfy.model_sampling.EPS):
|
|||||||
|
|
||||||
return c_out * x0 + c_skip * model_input
|
return c_out * x0 + c_skip * model_input
|
||||||
|
|
||||||
class ModelSamplingDiscreteDistilled(torch.nn.Module):
|
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
|
||||||
original_timesteps = 50
|
original_timesteps = 50
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, model_config=None):
|
||||||
super().__init__()
|
super().__init__(model_config)
|
||||||
self.sigma_data = 1.0
|
|
||||||
timesteps = 1000
|
|
||||||
beta_start = 0.00085
|
|
||||||
beta_end = 0.012
|
|
||||||
|
|
||||||
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
|
self.skip_steps = self.num_timesteps // self.original_timesteps
|
||||||
alphas = 1.0 - betas
|
|
||||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
|
||||||
|
|
||||||
self.skip_steps = timesteps // self.original_timesteps
|
sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
|
|
||||||
for x in range(self.original_timesteps):
|
for x in range(self.original_timesteps):
|
||||||
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
|
||||||
|
|
||||||
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
|
self.set_sigmas(sigmas_valid)
|
||||||
self.set_sigmas(sigmas)
|
|
||||||
|
|
||||||
def set_sigmas(self, sigmas):
|
|
||||||
self.register_buffer('sigmas', sigmas)
|
|
||||||
self.register_buffer('log_sigmas', sigmas.log())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sigma_min(self):
|
|
||||||
return self.sigmas[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def sigma_max(self):
|
|
||||||
return self.sigmas[-1]
|
|
||||||
|
|
||||||
def timestep(self, sigma):
|
def timestep(self, sigma):
|
||||||
log_sigma = sigma.log()
|
log_sigma = sigma.log()
|
||||||
@ -65,14 +43,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
|
|||||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||||
return log_sigma.exp().to(timestep.device)
|
return log_sigma.exp().to(timestep.device)
|
||||||
|
|
||||||
def percent_to_sigma(self, percent):
|
|
||||||
if percent <= 0.0:
|
|
||||||
return 999999999.9
|
|
||||||
if percent >= 1.0:
|
|
||||||
return 0.0
|
|
||||||
percent = 1.0 - percent
|
|
||||||
return self.sigma(torch.tensor(percent * 999.0)).item()
|
|
||||||
|
|
||||||
|
|
||||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||||
@ -121,7 +91,7 @@ class ModelSamplingDiscrete:
|
|||||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
model_sampling = ModelSamplingAdvanced()
|
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||||
if zsnr:
|
if zsnr:
|
||||||
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
|
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
|
||||||
|
|
||||||
@ -153,7 +123,7 @@ class ModelSamplingContinuousEDM:
|
|||||||
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
model_sampling = ModelSamplingAdvanced()
|
model_sampling = ModelSamplingAdvanced(model.model.model_config)
|
||||||
model_sampling.set_sigma_range(sigma_min, sigma_max)
|
model_sampling.set_sigma_range(sigma_min, sigma_max)
|
||||||
m.add_object_patch("model_sampling", model_sampling)
|
m.add_object_patch("model_sampling", model_sampling)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
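The rewrite above derives the distilled model's 50 usable sigmas by subsampling the parent class's full schedule from the end, one entry every skip_steps. The selection logic in isolation (the linspace is an illustrative stand-in for the real beta-derived schedule):

import torch

num_timesteps, original_timesteps = 1000, 50
skip_steps = num_timesteps // original_timesteps          # 20
sigmas = torch.linspace(0.01, 15.0, num_timesteps)        # illustrative schedule
valid = torch.zeros(original_timesteps)
for x in range(original_timesteps):
    valid[original_timesteps - 1 - x] = sigmas[num_timesteps - 1 - x * skip_steps]
assert valid[-1] == sigmas[-1]   # the highest-noise sigma is always kept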
|||||||
53
comfy_extras/nodes/nodes_perpneg.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import torch
|
||||||
|
from comfy import sample
|
||||||
|
from comfy import samplers
|
||||||
|
|
||||||
|
|
||||||
|
class PerpNeg:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"model": ("MODEL", ),
|
||||||
|
"empty_conditioning": ("CONDITIONING", ),
|
||||||
|
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def patch(self, model, empty_conditioning, neg_scale):
|
||||||
|
m = model.clone()
|
||||||
|
nocond = sample.convert_cond(empty_conditioning)
|
||||||
|
|
||||||
|
def cfg_function(args):
|
||||||
|
model = args["model"]
|
||||||
|
noise_pred_pos = args["cond_denoised"]
|
||||||
|
noise_pred_neg = args["uncond_denoised"]
|
||||||
|
cond_scale = args["cond_scale"]
|
||||||
|
x = args["input"]
|
||||||
|
sigma = args["sigma"]
|
||||||
|
model_options = args["model_options"]
|
||||||
|
nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
|
||||||
|
|
||||||
|
(noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
|
||||||
|
|
||||||
|
pos = noise_pred_pos - noise_pred_nocond
|
||||||
|
neg = noise_pred_neg - noise_pred_nocond
|
||||||
|
perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg
|
||||||
|
perp_neg = perp * neg_scale
|
||||||
|
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
|
||||||
|
cfg_result = x - cfg_result
|
||||||
|
return cfg_result
|
||||||
|
|
||||||
|
m.set_model_sampler_cfg_function(cfg_function)
|
||||||
|
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"PerpNeg": PerpNeg,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"PerpNeg": "Perp-Neg",
|
||||||
|
}
|
||||||
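The core of PerpNeg is a vector projection: perp is the component of the positive guidance direction that lies along the negative one, so subtracting neg_scale * perp removes only the overlap between the two. The projection in isolation:

import torch

pos = torch.tensor([2.0, 1.0])
neg = torch.tensor([1.0, 0.0])
perp = ((torch.mul(pos, neg).sum()) / (torch.norm(neg) ** 2)) * neg
print(perp)         # tensor([2., 0.]) -- pos's component along neg
print(pos - perp)   # tensor([0., 1.]) -- what survives full removal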
@ -226,7 +226,7 @@ class Sharpen:
|
|||||||
batch_size, height, width, channels = image.shape
|
batch_size, height, width, channels = image.shape
|
||||||
|
|
||||||
kernel_size = sharpen_radius * 2 + 1
|
kernel_size = sharpen_radius * 2 + 1
|
||||||
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
|
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
|
||||||
center = kernel_size // 2
|
center = kernel_size // 2
|
||||||
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
|
||||||
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
||||||
|
|||||||
@ -99,10 +99,40 @@ class LatentRebatch:
|
|||||||
|
|
||||||
return (output_list,)
|
return (output_list,)
|
||||||
|
|
||||||
|
class ImageRebatch:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "images": ("IMAGE",),
|
||||||
|
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
INPUT_IS_LIST = True
|
||||||
|
OUTPUT_IS_LIST = (True, )
|
||||||
|
|
||||||
|
FUNCTION = "rebatch"
|
||||||
|
|
||||||
|
CATEGORY = "image/batch"
|
||||||
|
|
||||||
|
def rebatch(self, images, batch_size):
|
||||||
|
batch_size = batch_size[0]
|
||||||
|
|
||||||
|
output_list = []
|
||||||
|
all_images = []
|
||||||
|
for img in images:
|
||||||
|
for i in range(img.shape[0]):
|
||||||
|
all_images.append(img[i:i+1])
|
||||||
|
|
||||||
|
for i in range(0, len(all_images), batch_size):
|
||||||
|
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
|
||||||
|
|
||||||
|
return (output_list,)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"RebatchLatents": LatentRebatch,
|
"RebatchLatents": LatentRebatch,
|
||||||
|
"RebatchImages": ImageRebatch,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
"RebatchLatents": "Rebatch Latents",
|
"RebatchLatents": "Rebatch Latents",
|
||||||
}
|
"RebatchImages": "Rebatch Images",
|
||||||
|
}
|
||||||
|
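ImageRebatch flattens every incoming batch into single images, then re-chunks them at the requested size; INPUT_IS_LIST / OUTPUT_IS_LIST make the node consume and produce lists of tensors. The chunking on its own:

import torch

all_images = [torch.randn(1, 8, 8, 3) for _ in range(5)]
batch_size = 2
batches = [torch.cat(all_images[i:i + batch_size], dim=0)
           for i in range(0, len(all_images), batch_size)]
print([b.shape[0] for b in batches])  # [2, 2, 1]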
|||||||
168
comfy_extras/nodes/nodes_sag.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
import torch
|
||||||
|
from torch import einsum
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
import os
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
|
||||||
|
from comfy import samplers
|
||||||
|
|
||||||
|
# from comfy/ldm/modules/attention.py
|
||||||
|
# but modified to return attention scores as well as output
|
||||||
|
def attention_basic_with_sim(q, k, v, heads, mask=None):
|
||||||
|
b, _, dim_head = q.shape
|
||||||
|
dim_head //= heads
|
||||||
|
scale = dim_head ** -0.5
|
||||||
|
|
||||||
|
h = heads
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.unsqueeze(3)
|
||||||
|
.reshape(b, -1, heads, dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b * heads, -1, dim_head)
|
||||||
|
.contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
|
# force cast to fp32 to avoid overflowing
|
||||||
|
if _ATTN_PRECISION =="fp32":
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||||
|
else:
|
||||||
|
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||||
|
|
||||||
|
del q, k
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = rearrange(mask, 'b ... -> b (...)')
|
||||||
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||||
|
sim.masked_fill_(~mask, max_neg_value)
|
||||||
|
|
||||||
|
# attention, what we cannot get enough of
|
||||||
|
sim = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
|
||||||
|
out = (
|
||||||
|
out.unsqueeze(0)
|
||||||
|
.reshape(b, heads, -1, dim_head)
|
||||||
|
.permute(0, 2, 1, 3)
|
||||||
|
.reshape(b, -1, heads * dim_head)
|
||||||
|
)
|
||||||
|
return (out, sim)
|
||||||
|
|
||||||
|
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
||||||
|
# reshape and GAP the attention map
|
||||||
|
_, hw1, hw2 = attn.shape
|
||||||
|
b, _, lh, lw = x0.shape
|
||||||
|
attn = attn.reshape(b, -1, hw1, hw2)
|
||||||
|
# Global Average Pool
|
||||||
|
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
||||||
|
ratio = math.ceil(math.sqrt(lh * lw / hw1))
|
||||||
|
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
|
||||||
|
|
||||||
|
# Reshape
|
||||||
|
mask = (
|
||||||
|
mask.reshape(b, *mid_shape)
|
||||||
|
.unsqueeze(1)
|
||||||
|
.type(attn.dtype)
|
||||||
|
)
|
||||||
|
# Upsample
|
||||||
|
mask = F.interpolate(mask, (lh, lw))
|
||||||
|
|
||||||
|
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
|
||||||
|
blurred = blurred * mask + x0 * (1 - mask)
|
||||||
|
return blurred
|
||||||
|
|
||||||
|
def gaussian_blur_2d(img, kernel_size, sigma):
|
||||||
|
ksize_half = (kernel_size - 1) * 0.5
|
||||||
|
|
||||||
|
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
|
||||||
|
|
||||||
|
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||||
|
|
||||||
|
x_kernel = pdf / pdf.sum()
|
||||||
|
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
|
||||||
|
|
||||||
|
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
|
||||||
|
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
|
||||||
|
|
||||||
|
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||||
|
|
||||||
|
img = F.pad(img, padding, mode="reflect")
|
||||||
|
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
|
||||||
|
return img
|
||||||
|
|
||||||
|
class SelfAttentionGuidance:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
|
||||||
|
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def patch(self, model, scale, blur_sigma):
|
||||||
|
m = model.clone()
|
||||||
|
|
||||||
|
attn_scores = None
|
||||||
|
|
||||||
|
# TODO: make this work properly with chunked batches
|
||||||
|
# currently, we can only save the attn from one UNet call
|
||||||
|
def attn_and_record(q, k, v, extra_options):
|
||||||
|
nonlocal attn_scores
|
||||||
|
# if uncond, save the attention scores
|
||||||
|
heads = extra_options["n_heads"]
|
||||||
|
cond_or_uncond = extra_options["cond_or_uncond"]
|
||||||
|
b = q.shape[0] // len(cond_or_uncond)
|
||||||
|
if 1 in cond_or_uncond:
|
||||||
|
uncond_index = cond_or_uncond.index(1)
|
||||||
|
# do the entire attention operation, but save the attention scores to attn_scores
|
||||||
|
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
|
||||||
|
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
|
||||||
|
n_slices = heads * b
|
||||||
|
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
return optimized_attention(q, k, v, heads=heads)
|
||||||
|
|
||||||
|
def post_cfg_function(args):
|
||||||
|
nonlocal attn_scores
|
||||||
|
uncond_attn = attn_scores
|
||||||
|
|
||||||
|
sag_scale = scale
|
||||||
|
sag_sigma = blur_sigma
|
||||||
|
sag_threshold = 1.0
|
||||||
|
model = args["model"]
|
||||||
|
uncond_pred = args["uncond_denoised"]
|
||||||
|
uncond = args["uncond"]
|
||||||
|
cfg_result = args["denoised"]
|
||||||
|
sigma = args["sigma"]
|
||||||
|
model_options = args["model_options"]
|
||||||
|
x = args["input"]
|
||||||
|
|
||||||
|
# create the adversarially blurred image
|
||||||
|
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
|
||||||
|
degraded_noised = degraded + x - uncond_pred
|
||||||
|
# call into the UNet
|
||||||
|
(sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
|
||||||
|
return cfg_result + (degraded - sag) * sag_scale
|
||||||
|
|
||||||
|
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
|
||||||
|
|
||||||
|
# from diffusers:
|
||||||
|
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
|
||||||
|
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
|
||||||
|
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"SelfAttentionGuidance": SelfAttentionGuidance,
|
||||||
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"SelfAttentionGuidance": "Self-Attention Guidance",
|
||||||
|
}
|
||||||
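gaussian_blur_2d above builds its 2-D kernel as the outer product of a normalized 1-D Gaussian, the standard separable construction. The kernel math on its own:

import torch

kernel_size, sigma = 9, 2.0
half = (kernel_size - 1) * 0.5
x = torch.linspace(-half, half, steps=kernel_size)
k1d = torch.exp(-0.5 * (x / sigma).pow(2))
k1d = k1d / k1d.sum()                        # normalized 1-D Gaussian
k2d = torch.mm(k1d[:, None], k1d[None, :])   # outer product -> 9x9 kernel
assert abs(k2d.sum().item() - 1.0) < 1e-6    # still sums to one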
46
comfy_extras/nodes/nodes_sdupscale.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import torch
|
||||||
|
from comfy import utils
|
||||||
|
|
||||||
|
class SD_4XUpscale_Conditioning:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "images": ("IMAGE",),
|
||||||
|
"positive": ("CONDITIONING",),
|
||||||
|
"negative": ("CONDITIONING",),
|
||||||
|
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
||||||
|
RETURN_NAMES = ("positive", "negative", "latent")
|
||||||
|
|
||||||
|
FUNCTION = "encode"
|
||||||
|
|
||||||
|
CATEGORY = "conditioning/upscale_diffusion"
|
||||||
|
|
||||||
|
def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
|
||||||
|
width = max(1, round(images.shape[-2] * scale_ratio))
|
||||||
|
height = max(1, round(images.shape[-3] * scale_ratio))
|
||||||
|
|
||||||
|
pixels = utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
|
||||||
|
|
||||||
|
out_cp = []
|
||||||
|
out_cn = []
|
||||||
|
|
||||||
|
for t in positive:
|
||||||
|
n = [t[0], t[1].copy()]
|
||||||
|
n[1]['concat_image'] = pixels
|
||||||
|
n[1]['noise_augmentation'] = noise_augmentation
|
||||||
|
out_cp.append(n)
|
||||||
|
|
||||||
|
for t in negative:
|
||||||
|
n = [t[0], t[1].copy()]
|
||||||
|
n[1]['concat_image'] = pixels
|
||||||
|
n[1]['noise_augmentation'] = noise_augmentation
|
||||||
|
out_cn.append(n)
|
||||||
|
|
||||||
|
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
|
||||||
|
return (out_cp, out_cn, {"samples":latent})
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
|
||||||
|
}
|
||||||
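Note: the easy mistake with this node is the factor-of-4 bookkeeping — width and height are the target pixel dimensions, while both the concat image and the returned latent are built at a quarter of that size. A quick standalone check (the input tensor here is a dummy, chosen only to make the arithmetic visible):

import torch

images = torch.rand(1, 128, 96, 3)  # ComfyUI IMAGE layout: [batch, height, width, channels]
scale_ratio = 4.0
width = max(1, round(images.shape[-2] * scale_ratio))   # 96 * 4 = 384
height = max(1, round(images.shape[-3] * scale_ratio))  # 128 * 4 = 512

# the conditioning image and the output latent both live at target/4
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
assert latent.shape == (1, 4, 128, 96)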
61 comfy_extras/nodes/nodes_stable3d.py Normal file
@ -0,0 +1,61 @@
import torch
import comfy.utils

from comfy.nodes.common import MAX_RESOLUTION
from comfy import utils


def camera_embeddings(elevation, azimuth):
    elevation = torch.as_tensor([elevation])
    azimuth = torch.as_tensor([azimuth])
    embeddings = torch.stack(
        [
            torch.deg2rad(
                (90 - elevation) - (90)
            ), # Zero123 polar is 90-elevation
            torch.sin(torch.deg2rad(azimuth)),
            torch.cos(torch.deg2rad(azimuth)),
            torch.deg2rad(
                90 - torch.full_like(elevation, 0)
            ),
        ], dim=-1).unsqueeze(1)

    return embeddings


class StableZero123_Conditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"clip_vision": ("CLIP_VISION",),
                             "init_image": ("IMAGE",),
                             "vae": ("VAE",),
                             "width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
                             "height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
                             "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
                             }}
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    FUNCTION = "encode"

    CATEGORY = "conditioning/3d_models"

    def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
        output = clip_vision.encode_image(init_image)
        pooled = output.image_embeds.unsqueeze(0)
        pixels = utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
        encode_pixels = pixels[:,:,:,:3]
        t = vae.encode(encode_pixels)
        cam_embeds = camera_embeddings(elevation, azimuth)
        cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)

        positive = [[cond, {"concat_latent_image": t}]]
        negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
        latent = torch.zeros([batch_size, 4, height // 8, width // 8])
        return (positive, negative, {"samples":latent})

NODE_CLASS_MAPPINGS = {
    "StableZero123_Conditioning": StableZero123_Conditioning,
}
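Note: camera_embeddings packs the relative camera pose into a [1, 1, 4] tensor — the polar term (90 - elevation) - 90 reduces to minus elevation in radians, followed by sin/cos of the azimuth and a constant 90-degree term. A standalone sanity check (the function is restated here only so the snippet runs on its own; the angle values are arbitrary examples):

import math
import torch

def camera_embeddings(elevation, azimuth):
    elevation = torch.as_tensor([elevation])
    azimuth = torch.as_tensor([azimuth])
    return torch.stack([
        torch.deg2rad((90 - elevation) - 90),               # == -elevation in radians
        torch.sin(torch.deg2rad(azimuth)),
        torch.cos(torch.deg2rad(azimuth)),
        torch.deg2rad(90 - torch.full_like(elevation, 0)),  # constant pi/2
    ], dim=-1).unsqueeze(1)

e = camera_embeddings(30.0, 90.0)
assert e.shape == (1, 1, 4)
assert torch.isclose(e[0, 0, 0], torch.tensor(-math.radians(30.0)))
assert torch.isclose(e[0, 0, 3], torch.tensor(math.pi / 2))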
@ -3,6 +3,7 @@ torchaudio
 torchvision
 torchdiffeq>=0.2.3
 torchsde>=0.2.6
+torchvision
 einops>=0.6.0
 open-clip-torch>=2.16.0
 transformers>=4.29.1
19 setup.py
@ -28,18 +28,18 @@ version = '0.0.1'
 """
 The package index to the torch built with AMD ROCm.
 """
-amd_torch_index = "https://download.pytorch.org/whl/rocm5.6"
+amd_torch_index = ("https://download.pytorch.org/whl/rocm5.6", "https://download.pytorch.org/whl/nightly/rocm5.7")

 """
 The package index to torch built with CUDA.
 Observe the CUDA version is in this URL.
 """
-nvidia_torch_index = "https://download.pytorch.org/whl/cu121"
+nvidia_torch_index = ("https://download.pytorch.org/whl/cu121", "https://download.pytorch.org/whl/nightly/cu121")

 """
 The package index to torch built against CPU features.
 """
-cpu_torch_index = "https://download.pytorch.org/whl/cpu"
+cpu_torch_index = ("https://download.pytorch.org/whl/cpu", "https://download.pytorch.org/whl/nightly/cpu")

 # xformers not required for new torch
@ -102,11 +102,11 @@ def _is_linux_arm64():

 def dependencies() -> List[str]:
     _dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines()
-    # todo: also add all plugin dependencies
     _alternative_indices = [amd_torch_index, nvidia_torch_index]
     session = PipSession()

-    index_urls = ['https://pypi.org/simple']
+    # (stable, nightly) tuple
+    index_urls = [('https://pypi.org/simple', 'https://pypi.org/simple')]
     # prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device
     if _is_nvidia():
         index_urls += [nvidia_torch_index]
@ -118,6 +118,13 @@ def dependencies() -> List[str]:
     if len(index_urls) == 1:
         return _dependencies

+    if sys.version_info >= (3, 12):
+        # use the nightlies
+        index_urls = [nightly for (_, nightly) in index_urls]
+        _alternative_indices = [nightly for (_, nightly) in _alternative_indices]
+    else:
+        index_urls = [stable for (stable, _) in index_urls]
+        _alternative_indices = [stable for (stable, _) in _alternative_indices]
     try:
         # pip 23
         finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls, no_index=False)),
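Note: the net effect of the tuple scheme is that every index entry now carries a (stable, nightly) pair and a single comprehension picks one channel for the whole list — presumably because only nightly torch wheels supported Python 3.12 at the time of this commit. A standalone sketch of the selection logic (URLs copied from the diff; the surrounding setup.py machinery is omitted):

import sys

nvidia_torch_index = ("https://download.pytorch.org/whl/cu121",
                      "https://download.pytorch.org/whl/nightly/cu121")
index_urls = [("https://pypi.org/simple", "https://pypi.org/simple"), nvidia_torch_index]

if sys.version_info >= (3, 12):
    # 3.12 is routed to the nightly channel: keep the second element of each pair
    index_urls = [nightly for (_, nightly) in index_urls]
else:
    index_urls = [stable for (stable, _) in index_urls]

print(index_urls)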
@ -149,7 +156,7 @@ setup(
     description="",
     author="",
     version=version,
-    python_requires=">=3.9,<3.12",
+    python_requires=">=3.9,<3.13",
     # todo: figure out how to include the web directory to eventually let main live inside the package
     # todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins
     packages=find_packages(exclude=[] if is_editable else ['custom_nodes']),
9 tests-ui/afterSetup.js Normal file
@ -0,0 +1,9 @@
const { start } = require("./utils");
const lg = require("./utils/litegraph");

// Load things once per test file before to ensure its all warmed up for the tests
beforeAll(async () => {
	lg.setup(global);
	await start({ resetEnv: true });
	lg.teardown(global);
});
@ -2,8 +2,10 @@
 const config = {
 	testEnvironment: "jsdom",
 	setupFiles: ["./globalSetup.js"],
+	setupFilesAfterEnv: ["./afterSetup.js"],
 	clearMocks: true,
 	resetModules: true,
+	testTimeout: 10000
 };

 module.exports = config;
@ -52,7 +52,7 @@ describe("extensions", () => {
 		const nodeNames = Object.keys(defs);
 		const nodeCount = nodeNames.length;
 		expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
-		for (let i = 0; i < nodeCount; i++) {
+		for (let i = 0; i < 10; i++) {
 			// It should be send the JS class and the original JSON definition
 			const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
 			const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
@ -133,7 +133,7 @@ describe("extensions", () => {
 		expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
 		expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
 		expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
-	});
+	}, 15000);

 	it("allows custom nodeDefs and widgets to be registered", async () => {
 		const widgetMock = jest.fn((node, inputName, inputData, app) => {
@ -1,7 +1,7 @@
 // @ts-check
 /// <reference path="../node_modules/@types/jest/index.d.ts" />

-const { start, createDefaultWorkflow } = require("../utils");
+const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils");
 const lg = require("../utils/litegraph");

 describe("group node", () => {
@ -273,7 +273,7 @@ describe("group node", () => {

 		let reroutes = [];
 		let prevNode = nodes.ckpt;
-		for(let i = 0; i < 5; i++) {
+		for (let i = 0; i < 5; i++) {
 			const reroute = ez.Reroute();
 			prevNode.outputs[0].connectTo(reroute.inputs[0]);
 			prevNode = reroute;
@ -283,7 +283,7 @@ describe("group node", () => {

 		const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
 		expect((await graph.toPrompt()).output).toEqual(getOutput());

 		group.menu["Convert to nodes"].call();
 		expect((await graph.toPrompt()).output).toEqual(getOutput());
 	});
@ -383,6 +383,43 @@ describe("group node", () => {
 			getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
 		);
 	});
+	test("groups can connect to each other via internal reroutes", async () => {
+		const { ez, graph, app } = await start();
+
+		const latent = ez.EmptyLatentImage();
+		const vae = ez.VAELoader();
+		const latentReroute = ez.Reroute();
+		const vaeReroute = ez.Reroute();
+
+		latent.outputs[0].connectTo(latentReroute.inputs[0]);
+		vae.outputs[0].connectTo(vaeReroute.inputs[0]);
+
+		const group1 = await convertToGroup(app, graph, "test", [latentReroute, vaeReroute]);
+		group1.menu.Clone.call();
+		expect(app.graph._nodes).toHaveLength(4);
+		const group2 = graph.find(app.graph._nodes[3]);
+		expect(group2.node.type).toEqual("workflow/test");
+		expect(group2.id).not.toEqual(group1.id);
+
+		group1.outputs.VAE.connectTo(group2.inputs.VAE);
+		group1.outputs.LATENT.connectTo(group2.inputs.LATENT);
+
+		const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE);
+		const preview = ez.PreviewImage(decode.outputs[0]);
+
+		const output = {
+			[latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" },
+			[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
+			[decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" },
+			[preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" },
+		};
+		expect((await graph.toPrompt()).output).toEqual(output);
+
+		// Ensure missing connections dont cause errors
+		group2.inputs.VAE.disconnect();
+		delete output[decode.id].inputs.vae;
+		expect((await graph.toPrompt()).output).toEqual(output);
+	});
 	test("displays generated image on group node", async () => {
 		const { ez, graph, app } = await start();
 		const nodes = createDefaultWorkflow(ez, graph);
@ -642,6 +679,55 @@ describe("group node", () => {
 			2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
 		});
 	});
+	test("correctly handles widget inputs", async () => {
+		const { ez, graph, app } = await start();
+		const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0];
+
+		const image = ez.LoadImage();
+		const scale1 = ez.ImageScaleBy(image.outputs[0]);
+		const scale2 = ez.ImageScaleBy(image.outputs[0]);
+		const preview1 = ez.PreviewImage(scale1.outputs[0]);
+		const preview2 = ez.PreviewImage(scale2.outputs[0]);
+		scale1.widgets.upscale_method.value = upscaleMethods[1];
+		scale1.widgets.upscale_method.convertToInput();
+
+		const group = await convertToGroup(app, graph, "test", [scale1, scale2]);
+		expect(group.inputs.length).toBe(3);
+		expect(group.inputs[0].input.type).toBe("IMAGE");
+		expect(group.inputs[1].input.type).toBe("IMAGE");
+		expect(group.inputs[2].input.type).toBe("COMBO");
+
+		// Ensure links are maintained
+		expect(group.inputs[0].connection?.originNode?.id).toBe(image.id);
+		expect(group.inputs[1].connection?.originNode?.id).toBe(image.id);
+		expect(group.inputs[2].connection).toBeFalsy();
+
+		// Ensure primitive gets correct type
+		const primitive = ez.PrimitiveNode();
+		primitive.outputs[0].connectTo(group.inputs[2]);
+		expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods);
+		expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied
+		primitive.widgets.value.value = upscaleMethods[1];
+
+		await checkBeforeAndAfterReload(graph, async (r) => {
+			const scale1id = r ? `${group.id}:0` : scale1.id;
+			const scale2id = r ? `${group.id}:1` : scale2.id;
+			// Ensure widget value is applied to prompt
+			expect((await graph.toPrompt()).output).toStrictEqual({
+				[image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
+				[scale1id]: {
+					inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
+					class_type: "ImageScaleBy",
+				},
+				[scale2id]: {
+					inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
+					class_type: "ImageScaleBy",
+				},
+				[preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" },
+				[preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" },
+			});
+		});
+	});
 	test("adds widgets in node execution order", async () => {
 		const { ez, graph, app } = await start();
 		const scale = ez.LatentUpscale();
@ -815,4 +901,105 @@ describe("group node", () => {
 		expect(p2.widgets.control_after_generate.value).toBe("randomize");
 		expect(p2.widgets.control_filter_list.value).toBe("/.+/");
 	});
+	test("internal reroutes work with converted inputs and merge options", async () => {
+		const { ez, graph, app } = await start();
+		const vae = ez.VAELoader();
+		const latent = ez.EmptyLatentImage();
+		const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE);
+		const scale = ez.ImageScale(decode.outputs.IMAGE);
+		ez.PreviewImage(scale.outputs.IMAGE);
+
+		const r1 = ez.Reroute();
+		const r2 = ez.Reroute();
+
+		latent.widgets.width.value = 64;
+		latent.widgets.height.value = 128;
+
+		latent.widgets.width.convertToInput();
+		latent.widgets.height.convertToInput();
+		latent.widgets.batch_size.convertToInput();
+
+		scale.widgets.width.convertToInput();
+		scale.widgets.height.convertToInput();
+
+		r1.inputs[0].input.label = "hbw";
+		r1.outputs[0].connectTo(latent.inputs.height);
+		r1.outputs[0].connectTo(latent.inputs.batch_size);
+		r1.outputs[0].connectTo(scale.inputs.width);
+
+		r2.inputs[0].input.label = "wh";
+		r2.outputs[0].connectTo(latent.inputs.width);
+		r2.outputs[0].connectTo(scale.inputs.height);
+
+		const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]);
+
+		expect(group.inputs[0].input.type).toBe("VAE");
+		expect(group.inputs[1].input.type).toBe("INT");
+		expect(group.inputs[2].input.type).toBe("INT");
+
+		const p1 = ez.PrimitiveNode();
+		const p2 = ez.PrimitiveNode();
+		p1.outputs[0].connectTo(group.inputs[1]);
+		p2.outputs[0].connectTo(group.inputs[2]);

+		expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min
+		expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max
+		expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
+
+		expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
+		expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
+		expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
+
+		expect(p1.widgets.value.value).toBe(128);
+		expect(p2.widgets.value.value).toBe(64);
+
+		p1.widgets.value.value = 16;
+		p2.widgets.value.value = 32;
+
+		await checkBeforeAndAfterReload(graph, async (r) => {
+			const id = (v) => (r ? `${group.id}:` : "") + v;
+			expect((await graph.toPrompt()).output).toStrictEqual({
+				1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
+				[id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" },
+				[id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" },
+				[id(4)]: {
+					inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] },
+					class_type: "ImageScale",
+				},
+				5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" },
+			});
+		});
+	});
+	test("converted inputs with linked widgets map values correctly on creation", async () => {
+		const { ez, graph, app } = await start();
+		const k1 = ez.KSampler();
+		const k2 = ez.KSampler();
+		k1.widgets.seed.convertToInput();
+		k2.widgets.seed.convertToInput();
+
+		const rr = ez.Reroute();
+		rr.outputs[0].connectTo(k1.inputs.seed);
+		rr.outputs[0].connectTo(k2.inputs.seed);
+
+		const group = await convertToGroup(app, graph, "test", [k1, k2, rr]);
+		expect(group.widgets.steps.value).toBe(20);
+		expect(group.widgets.cfg.value).toBe(8);
+		expect(group.widgets.scheduler.value).toBe("normal");
+		expect(group.widgets["KSampler steps"].value).toBe(20);
+		expect(group.widgets["KSampler cfg"].value).toBe(8);
+		expect(group.widgets["KSampler scheduler"].value).toBe("normal");
+	});
+	test("allow multiple of the same node type to be added", async () => {
+		const { ez, graph, app } = await start();
+		const nodes = [...Array(10)].map(() => ez.ImageScaleBy());
+		const group = await convertToGroup(app, graph, "test", nodes);
+		expect(group.inputs.length).toBe(10);
+		expect(group.outputs.length).toBe(10);
+		expect(group.widgets.length).toBe(20);
+		expect(group.widgets.map((w) => w.widget.name)).toStrictEqual(
+			[...Array(10)]
+				.map((_, i) => `${i > 0 ? "ImageScaleBy " : ""}${i > 1 ? i + " " : ""}`)
+				.flatMap((p) => [`${p}upscale_method`, `${p}scale_by`])
+		);
+	});
 });
@ -1,7 +1,13 @@
 // @ts-check
 /// <reference path="../node_modules/@types/jest/index.d.ts" />

-const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils");
+const {
+	start,
+	makeNodeDef,
+	checkBeforeAndAfterReload,
+	assertNotNullOrUndefined,
+	createDefaultWorkflow,
+} = require("../utils");
 const lg = require("../utils/litegraph");

 /**
@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount) {
 	if (controlWidgetCount) {
 		const controlWidget = primitive.widgets.control_after_generate;
 		expect(controlWidget.widget.type).toBe("combo");
-		if(widgetType === "combo") {
+		if (widgetType === "combo") {
 			const filterWidget = primitive.widgets.control_filter_list;
 			expect(filterWidget.widget.type).toBe("string");
 		}
@ -308,8 +314,8 @@ describe("widget inputs", () => {
 		const { ez } = await start({
 			mockNodeDefs: {
 				...makeNodeDef("TestNode1", {}, [["A", "B"]]),
-				...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }),
-				...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }),
+				...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
+				...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }),
 			},
 		});
@ -330,7 +336,7 @@ describe("widget inputs", () => {

 		const n1 = ez.TestNode1();
 		n1.widgets.example.convertToInput();
-		const p = ez.PrimitiveNode()
+		const p = ez.PrimitiveNode();
 		p.outputs[0].connectTo(n1.inputs[0]);

 		const value = p.widgets.value;
@ -380,7 +386,7 @@ describe("widget inputs", () => {
 		// Check random
 		control.value = "randomize";
 		filter.value = "/D/";
-		for(let i = 0; i < 100; i++) {
+		for (let i = 0; i < 100; i++) {
 			control["afterQueued"]();
 			expect(value.value === "D" || value.value === "DD").toBeTruthy();
 		}
@ -392,4 +398,160 @@ describe("widget inputs", () => {
 		control["afterQueued"]();
 		expect(value.value).toBe("B");
 	});
+
+	describe("reroutes", () => {
+		async function checkOutput(graph, values) {
+			expect((await graph.toPrompt()).output).toStrictEqual({
+				1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" },
+				2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
+				3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
+				4: {
+					inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 },
+					class_type: "EmptyLatentImage",
+				},
+				5: {
+					inputs: {
+						seed: 0,
+						steps: 20,
+						cfg: 8,
+						sampler_name: "euler",
+						scheduler: values?.scheduler ?? "normal",
+						denoise: 1,
+						model: ["1", 0],
+						positive: ["2", 0],
+						negative: ["3", 0],
+						latent_image: ["4", 0],
+					},
+					class_type: "KSampler",
+				},
+				6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" },
+				7: {
+					inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] },
+					class_type: "SaveImage",
+				},
+			});
+		}
+
+		async function waitForWidget(node) {
+			// widgets are created slightly after the graph is ready
+			// hard to find an exact hook to get these so just wait for them to be ready
+			for (let i = 0; i < 10; i++) {
+				await new Promise((r) => setTimeout(r, 10));
+				if (node.widgets?.value) {
+					return;
+				}
+			}
+		}
+
+		it("can connect primitive via a reroute path to a widget input", async () => {
+			const { ez, graph } = await start();
+			const nodes = createDefaultWorkflow(ez, graph);
+
+			nodes.empty.widgets.width.convertToInput();
+			nodes.sampler.widgets.scheduler.convertToInput();
+			nodes.save.widgets.filename_prefix.convertToInput();
+
+			let widthReroute = ez.Reroute();
+			let schedulerReroute = ez.Reroute();
+			let fileReroute = ez.Reroute();
+
+			let widthNext = widthReroute;
+			let schedulerNext = schedulerReroute;
+			let fileNext = fileReroute;
+
+			for (let i = 0; i < 5; i++) {
+				let next = ez.Reroute();
+				widthNext.outputs[0].connectTo(next.inputs[0]);
+				widthNext = next;
+
+				next = ez.Reroute();
+				schedulerNext.outputs[0].connectTo(next.inputs[0]);
+				schedulerNext = next;
+
+				next = ez.Reroute();
+				fileNext.outputs[0].connectTo(next.inputs[0]);
+				fileNext = next;
+			}
+
+			widthNext.outputs[0].connectTo(nodes.empty.inputs.width);
+			schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler);
+			fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix);
+
+			let widthPrimitive = ez.PrimitiveNode();
+			let schedulerPrimitive = ez.PrimitiveNode();
+			let filePrimitive = ez.PrimitiveNode();
+
+			widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]);
+			schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]);
+			filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]);
+			expect(widthPrimitive.widgets.value.value).toBe(512);
+			widthPrimitive.widgets.value.value = 1024;
+			expect(schedulerPrimitive.widgets.value.value).toBe("normal");
+			schedulerPrimitive.widgets.value.value = "simple";
+			expect(filePrimitive.widgets.value.value).toBe("ComfyUI");
+			filePrimitive.widgets.value.value = "ComfyTest";
+
+			await checkBeforeAndAfterReload(graph, async () => {
+				widthPrimitive = graph.find(widthPrimitive);
+				schedulerPrimitive = graph.find(schedulerPrimitive);
+				filePrimitive = graph.find(filePrimitive);
+				await waitForWidget(filePrimitive);
+				expect(widthPrimitive.widgets.length).toBe(2);
+				expect(schedulerPrimitive.widgets.length).toBe(3);
+				expect(filePrimitive.widgets.length).toBe(1);
+
+				await checkOutput(graph, {
+					width: 1024,
+					scheduler: "simple",
+					filename_prefix: "ComfyTest",
+				});
+			});
+		});
+		it("can connect primitive via a reroute path to multiple widget inputs", async () => {
+			const { ez, graph } = await start();
+			const nodes = createDefaultWorkflow(ez, graph);
+
+			nodes.empty.widgets.width.convertToInput();
+			nodes.empty.widgets.height.convertToInput();
+			nodes.empty.widgets.batch_size.convertToInput();
+
+			let reroute = ez.Reroute();
+			let prevReroute = reroute;
+			for (let i = 0; i < 5; i++) {
+				const next = ez.Reroute();
+				prevReroute.outputs[0].connectTo(next.inputs[0]);
+				prevReroute = next;
+			}
+
+			const r1 = ez.Reroute(prevReroute.outputs[0]);
+			const r2 = ez.Reroute(prevReroute.outputs[0]);
+			const r3 = ez.Reroute(r2.outputs[0]);
+			const r4 = ez.Reroute(r2.outputs[0]);
+
+			r1.outputs[0].connectTo(nodes.empty.inputs.width);
+			r3.outputs[0].connectTo(nodes.empty.inputs.height);
+			r4.outputs[0].connectTo(nodes.empty.inputs.batch_size);
+
+			let primitive = ez.PrimitiveNode();
+			primitive.outputs[0].connectTo(reroute.inputs[0]);
+			expect(primitive.widgets.value.value).toBe(1);
+			primitive.widgets.value.value = 64;
+
+			await checkBeforeAndAfterReload(graph, async (r) => {
+				primitive = graph.find(primitive);
+				await waitForWidget(primitive);
+
+				// Ensure widget configs are merged
+				expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min
+				expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max
+				expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
+
+				await checkOutput(graph, {
+					width: 64,
+					height: 64,
+					batch_size: 64,
+				});
+			});
+		});
+	});
 });
@ -78,6 +78,14 @@ export class EzInput extends EzSlot {
 		this.input = input;
 	}

+	get connection() {
+		const link = this.node.node.inputs?.[this.index]?.link;
+		if (link == null) {
+			return null;
+		}
+		return new EzConnection(this.node.app, this.node.app.graph.links[link]);
+	}
+
 	disconnect() {
 		this.node.node.disconnectInput(this.index);
 	}
@ -117,7 +125,7 @@ export class EzOutput extends EzSlot {
 		const inp = input.input;
 		const inName = inp.name || inp.label || inp.type;
 		throw new Error(
-			`Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${
+			`Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${
 				this.output.name ?? this.output.type
 			}#${this.index}] failed.`
 		);
@ -179,6 +187,7 @@ export class EzWidget {

 	set value(v) {
 		this.widget.value = v;
+		this.widget.callback?.call?.(this.widget, v)
 	}

 	get isConvertedToInput() {
@ -319,7 +328,7 @@ export class EzGraph {
 	}

 	stringify() {
-		return JSON.stringify(this.app.graph.serialize(), undefined, "\t");
+		return JSON.stringify(this.app.graph.serialize(), undefined);
 	}

 	/**
@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) {

 	return { ckpt, pos, neg, empty, sampler, decode, save };
 }
+
+export async function getNodeDefs() {
+	const { api } = require("../../web/scripts/api");
+	return api.getNodeDefs();
+}
+
+export async function getNodeDef(nodeId) {
+	return (await getNodeDefs())[nodeId];
+}
@ -174,6 +174,11 @@ export class GroupNodeConfig {
 			node.index = i;
 			this.processNode(node, seenInputs, seenOutputs);
 		}
+
+		for (const p of this.#convertedToProcess) {
+			p();
+		}
+		this.#convertedToProcess = null;
 		await app.registerNodeDef("workflow/" + this.name, this.nodeDef);
 	}
@ -192,7 +197,10 @@ export class GroupNodeConfig {
 		if (!this.linksFrom[sourceNodeId]) {
 			this.linksFrom[sourceNodeId] = {};
 		}
-		this.linksFrom[sourceNodeId][sourceNodeSlot] = l;
+		if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) {
+			this.linksFrom[sourceNodeId][sourceNodeSlot] = [];
+		}
+		this.linksFrom[sourceNodeId][sourceNodeSlot].push(l);

 		if (!this.linksTo[targetNodeId]) {
 			this.linksTo[targetNodeId] = {};
@ -230,11 +238,11 @@ export class GroupNodeConfig {
 		// Skip as its not linked
 		if (!linksFrom) return;

-		let type = linksFrom["0"][5];
+		let type = linksFrom["0"][0][5];
 		if (type === "COMBO") {
 			// Use the array items
 			const source = node.outputs[0].widget.name;
-			const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type;
+			const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type;
 			const fromType = globalDefs[fromTypeName];
 			const input = fromType.input.required[source] ?? fromType.input.optional[source];
 			type = input[0];
@ -258,10 +266,33 @@ export class GroupNodeConfig {
 			return null;
 		}

+		let config = {};
 		let rerouteType = "*";
 		if (linksFrom) {
-			const [, , id, slot] = linksFrom["0"];
-			rerouteType = this.nodeData.nodes[id].inputs[slot].type;
+			for (const [, , id, slot] of linksFrom["0"]) {
+				const node = this.nodeData.nodes[id];
+				const input = node.inputs[slot];
+				if (rerouteType === "*") {
+					rerouteType = input.type;
+				}
+				if (input.widget) {
+					const targetDef = globalDefs[node.type];
+					const targetWidget =
+						targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
+
+					const widget = [targetWidget[0], config];
+					const res = mergeIfValid(
+						{
+							widget,
+						},
+						targetWidget,
+						false,
+						null,
+						widget
+					);
+					config = res?.customConfig ?? config;
+				}
+			}
 		} else if (linksTo) {
 			const [id, slot] = linksTo["0"];
 			rerouteType = this.nodeData.nodes[id].outputs[slot].type;
@ -282,10 +313,11 @@ export class GroupNodeConfig {
 			}
 		}

+		config.forceInput = true;
 		return {
 			input: {
 				required: {
-					[rerouteType]: [rerouteType, {}],
+					[rerouteType]: [rerouteType, config],
 				},
 			},
 			output: [rerouteType],
@ -299,16 +331,17 @@ export class GroupNodeConfig {

 	getInputConfig(node, inputName, seenInputs, config, extra) {
 		let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
+		let key = name;
 		let prefix = "";
 		// Special handling for primitive to include the title if it is set rather than just "value"
 		if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) {
 			prefix = `${node.title ?? node.type} `;
-			name = `${prefix}${inputName}`;
+			key = name = `${prefix}${inputName}`;
 			if (name in seenInputs) {
 				name = `${prefix}${seenInputs[name]} ${inputName}`;
 			}
 		}
-		seenInputs[name] = (seenInputs[name] ?? 1) + 1;
+		seenInputs[key] = (seenInputs[key] ?? 1) + 1;

 		if (inputName === "seed" || inputName === "noise_seed") {
 			if (!extra) extra = {};
@ -420,10 +453,18 @@ export class GroupNodeConfig {
 				defaultInput: true,
 			});
 			this.nodeDef.input.required[name] = config;
+			this.newToOldWidgetMap[name] = { node, inputName };
+
+			if (!this.oldToNewWidgetMap[node.index]) {
+				this.oldToNewWidgetMap[node.index] = {};
+			}
+			this.oldToNewWidgetMap[node.index][inputName] = name;
+
 			inputMap[slots.length + i] = this.inputCount++;
 		}
 	}

+	#convertedToProcess = [];
 	processNodeInputs(node, seenInputs, inputs) {
 		const inputMapping = [];
@ -434,7 +475,11 @@ export class GroupNodeConfig {
 		const linksTo = this.linksTo[node.index] ?? {};
 		const inputMap = (this.oldToNewInputMap[node.index] = {});
 		this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
-		this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs);
+
+		// Converted inputs have to be processed after all other nodes as they'll be at the end of the list
+		this.#convertedToProcess.push(() =>
+			this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)
+		);

 		return inputMapping;
 	}
@ -597,11 +642,19 @@ export class GroupNodeHandler {
 			const output = this.groupData.newToOldOutputMap[link.origin_slot];
 			let innerNode = this.innerNodes[output.node.index];
 			let l;
-			while (innerNode.type === "Reroute") {
+			while (innerNode?.type === "Reroute") {
 				l = innerNode.getInputLink(0);
 				innerNode = innerNode.getInputNode(0);
 			}
+
+			if (!innerNode) {
+				return null;
+			}
+
+			if (l && GroupNodeHandler.isGroupNode(innerNode)) {
+				return innerNode.updateLink(l);
+			}
+
 			link.origin_id = innerNode.id;
 			link.origin_slot = l?.origin_slot ?? output.slot;
 			return link;
@ -665,6 +718,8 @@ export class GroupNodeHandler {
 				top = newNode.pos[1];
 			}

+			if (!newNode.widgets) continue;
+
 			const map = this.groupData.oldToNewWidgetMap[innerNode.index];
 			if (map) {
 				const widgets = Object.keys(map);
@ -721,7 +776,7 @@ export class GroupNodeHandler {
 			}
 		};

-		const reconnectOutputs = () => {
+		const reconnectOutputs = (selectedIds) => {
 			for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) {
 				const output = node.outputs[groupOutputId];
 				if (!output.links) continue;
@ -861,7 +916,7 @@ export class GroupNodeHandler {
 			if (innerNode.type === "PrimitiveNode") {
 				innerNode.primitiveValue = newValue;
 				const primitiveLinked = this.groupData.primitiveToWidget[old.node.index];
-				for (const linked of primitiveLinked) {
+				for (const linked of primitiveLinked ?? []) {
 					const node = this.innerNodes[linked.nodeId];
 					const widget = node.widgets.find((w) => w.name === linked.inputName);
@ -870,6 +925,18 @@ export class GroupNodeHandler {
 				}
 			}
 			continue;
+		} else if (innerNode.type === "Reroute") {
+			const rerouteLinks = this.groupData.linksFrom[old.node.index];
+			for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
+				const node = this.innerNodes[targetNodeId];
+				const input = node.inputs[targetSlot];
+				if (input.widget) {
+					const widget = node.widgets?.find((w) => w.name === input.widget.name);
+					if (widget) {
+						widget.value = newValue;
+					}
+				}
+			}
 		}

 		const widget = innerNode.widgets?.find((w) => w.name === old.inputName);
@ -897,33 +964,58 @@ export class GroupNodeHandler {
 				this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value;
 			}
 		}
+		return true;
 	}

+	populateReroute(node, nodeId, map) {
+		if (node.type !== "Reroute") return;
+
+		const link = this.groupData.linksFrom[nodeId]?.[0]?.[0];
+		if (!link) return;
+		const [, , targetNodeId, targetNodeSlot] = link;
+		const targetNode = this.groupData.nodeData.nodes[targetNodeId];
+		const inputs = targetNode.inputs;
+		const targetWidget = inputs?.[targetNodeSlot].widget;
+		if (!targetWidget) return;
+
+		const offset = inputs.length - (targetNode.widgets_values?.length ?? 0);
+		const v = targetNode.widgets_values?.[targetNodeSlot - offset];
+		if (v == null) return;
+
+		const widgetName = Object.values(map)[0];
+		const widget = this.node.widgets.find(w => w.name === widgetName);
+		if(widget) {
+			widget.value = v;
+		}
+	}
+
 	populateWidgets() {
+		if (!this.node.widgets) return;
+
 		for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) {
 			const node = this.groupData.nodeData.nodes[nodeId];
-			if (!node.widgets_values?.length) continue;
-
-			const map = this.groupData.oldToNewWidgetMap[nodeId];
+			const map = this.groupData.oldToNewWidgetMap[nodeId] ?? {};
 			const widgets = Object.keys(map);
+
+			if (!node.widgets_values?.length) {
+				// special handling for populating values into reroutes
+				// this allows primitives connect to them to pick up the correct value
+				this.populateReroute(node, nodeId, map);
+				continue;
+			}
+
 			let linkedShift = 0;
 			for (let i = 0; i < widgets.length; i++) {
 				const oldName = widgets[i];
 				const newName = map[oldName];
 				const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName);
 				const mainWidget = this.node.widgets[widgetIndex];
-				if (!newName) {
-					// New name will be null if its a converted widget
-					this.populatePrimitive(node, nodeId, oldName, i, linkedShift);
+				if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift) || widgetIndex === -1) {
 					// Find the inner widget and shift by the number of linked widgets as they will have been removed too
 					const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName);
-					linkedShift += innerWidget.linkedWidgets?.length ?? 0;
+					linkedShift += innerWidget?.linkedWidgets?.length ?? 0;
-					continue;
 				}
 				if (widgetIndex === -1) {
 					continue;
 				}
@ -33,6 +33,18 @@ function loadedImageToBlob(image) {
 	return blob;
 }

+function loadImage(imagePath) {
+	return new Promise((resolve, reject) => {
+		const image = new Image();
+
+		image.onload = function() {
+			resolve(image);
+		};
+
+		image.src = imagePath;
+	});
+}
+
 async function uploadMask(filepath, formData) {
 	await api.fetchApi('/upload/mask', {
 		method: 'POST',
@ -50,25 +62,25 @@ async function uploadMask(filepath, formData) {
 	ClipspaceDialog.invalidatePreview();
 }

-function prepareRGB(image, backupCanvas, backupCtx) {
+function prepare_mask(image, maskCanvas, maskCtx) {
 	// paste mask data into alpha channel
-	backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height);
-	const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
+	maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height);
+	const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height);

-	// refine mask image
-	for (let i = 0; i < backupData.data.length; i += 4) {
-		if(backupData.data[i+3] == 255)
-			backupData.data[i+3] = 0;
+	// invert mask
+	for (let i = 0; i < maskData.data.length; i += 4) {
+		if(maskData.data[i+3] == 255)
+			maskData.data[i+3] = 0;
 		else
-			backupData.data[i+3] = 255;
+			maskData.data[i+3] = 255;

-		backupData.data[i] = 0;
-		backupData.data[i+1] = 0;
-		backupData.data[i+2] = 0;
+		maskData.data[i] = 0;
+		maskData.data[i+1] = 0;
+		maskData.data[i+2] = 0;
 	}

-	backupCtx.globalCompositeOperation = 'source-over';
-	backupCtx.putImageData(backupData, 0, 0);
+	maskCtx.globalCompositeOperation = 'source-over';
+	maskCtx.putImageData(maskData, 0, 0);
 }

 class MaskEditorDialog extends ComfyDialog {
@ -155,10 +167,6 @@ class MaskEditorDialog extends ComfyDialog {

 		// If it is specified as relative, using it only as a hidden placeholder for padding is recommended
 		// to prevent anomalies where it exceeds a certain size and goes outside of the window.
-		var placeholder = document.createElement("div");
-		placeholder.style.position = "relative";
-		placeholder.style.height = "50px";

 		var bottom_panel = document.createElement("div");
 		bottom_panel.style.position = "absolute";
 		bottom_panel.style.bottom = "0px";
@ -180,18 +188,16 @@ class MaskEditorDialog extends ComfyDialog {
 		this.brush = brush;
 		this.element.appendChild(imgCanvas);
 		this.element.appendChild(maskCanvas);
-		this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
 		this.element.appendChild(bottom_panel);
 		document.body.appendChild(brush);

-		var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
+		this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
 			self.brush_size = event.target.value;
 			self.updateBrushPreview(self, null, null);
 		});
 		var clearButton = this.createLeftButton("Clear",
 			() => {
 				self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
-				self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
 			});
 		var cancelButton = this.createRightButton("Cancel", () => {
 			document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
@ -207,40 +213,42 @@ class MaskEditorDialog extends ComfyDialog {

 		this.element.appendChild(imgCanvas);
 		this.element.appendChild(maskCanvas);
-		this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
 		this.element.appendChild(bottom_panel);

 		bottom_panel.appendChild(clearButton);
 		bottom_panel.appendChild(this.saveButton);
 		bottom_panel.appendChild(cancelButton);
-		bottom_panel.appendChild(brush_size_slider);
+		bottom_panel.appendChild(this.brush_size_slider);

+		imgCanvas.style.position = "absolute";
+		maskCanvas.style.position = "absolute";

-		imgCanvas.style.position = "relative";
 		imgCanvas.style.top = "200";
 		imgCanvas.style.left = "0";

-		maskCanvas.style.position = "absolute";
+		maskCanvas.style.top = imgCanvas.style.top;
+		maskCanvas.style.left = imgCanvas.style.left;
 	}

-	show() {
+	async show() {
+		this.zoom_ratio = 1.0;
+		this.pan_x = 0;
+		this.pan_y = 0;
+
 		if(!this.is_layout_created) {
 			// layout
 			const imgCanvas = document.createElement('canvas');
 			const maskCanvas = document.createElement('canvas');
-			const backupCanvas = document.createElement('canvas');

 			imgCanvas.id = "imageCanvas";
 			maskCanvas.id = "maskCanvas";
-			backupCanvas.id = "backupCanvas";

 			this.setlayout(imgCanvas, maskCanvas);

 			// prepare content
 			this.imgCanvas = imgCanvas;
 			this.maskCanvas = maskCanvas;
-			this.backupCanvas = backupCanvas;
-			this.maskCtx = maskCanvas.getContext('2d');
-			this.backupCtx = backupCanvas.getContext('2d');
+			this.maskCtx = maskCanvas.getContext('2d', {willReadFrequently: true });

 			this.setEventHandler(maskCanvas);
@ -252,6 +260,8 @@ class MaskEditorDialog extends ComfyDialog {
 			mutations.forEach(function(mutation) {
 				if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
 					if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
+						document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
+						self.brush.style.display = "none";
 						ComfyApp.onClipspaceEditorClosed();
 					}

@ -264,7 +274,8 @@ class MaskEditorDialog extends ComfyDialog {
 			observer.observe(this.element, config);
 		}

-		this.setImages(this.imgCanvas, this.backupCanvas);
+		// The keydown event needs to be reconfigured when closing the dialog as it gets removed.
+		document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);

 		if(ComfyApp.clipspace_return_node) {
 			this.saveButton.innerText = "Save to node";
|
this.saveButton.innerText = "Save to node";
|
||||||
@ -275,97 +286,157 @@ class MaskEditorDialog extends ComfyDialog {
|
|||||||
this.saveButton.disabled = false;
|
this.saveButton.disabled = false;
|
||||||
|
|
||||||
this.element.style.display = "block";
|
this.element.style.display = "block";
|
||||||
|
this.element.style.width = "85%";
|
||||||
|
this.element.style.margin = "0 7.5%";
|
||||||
|
this.element.style.height = "100vh";
|
||||||
|
this.element.style.top = "50%";
|
||||||
|
this.element.style.left = "42%";
|
||||||
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
|
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
|
||||||
|
|
||||||
|
await this.setImages(this.imgCanvas);
|
||||||
|
|
||||||
|
this.is_visible = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
isOpened() {
|
isOpened() {
|
||||||
return this.element.style.display == "block";
|
return this.element.style.display == "block";
|
||||||
}
|
}
|
||||||
|
|
||||||
setImages(imgCanvas, backupCanvas) {
|
invalidateCanvas(orig_image, mask_image) {
|
||||||
const imgCtx = imgCanvas.getContext('2d');
|
this.imgCanvas.width = orig_image.width;
|
||||||
const backupCtx = backupCanvas.getContext('2d');
|
this.imgCanvas.height = orig_image.height;
|
||||||
|
|
||||||
|
this.maskCanvas.width = orig_image.width;
|
||||||
|
this.maskCanvas.height = orig_image.height;
|
||||||
|
|
||||||
|
let imgCtx = this.imgCanvas.getContext('2d', {willReadFrequently: true });
|
||||||
|
let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true });
|
||||||
|
|
||||||
|
imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height);
|
||||||
|
prepare_mask(mask_image, this.maskCanvas, maskCtx);
|
||||||
|
}
|
||||||
|
|
||||||
|
async setImages(imgCanvas) {
|
||||||
|
let self = this;
|
||||||
|
|
||||||
|
const imgCtx = imgCanvas.getContext('2d', {willReadFrequently: true });
|
||||||
const maskCtx = this.maskCtx;
|
const maskCtx = this.maskCtx;
|
||||||
const maskCanvas = this.maskCanvas;
|
const maskCanvas = this.maskCanvas;
|
||||||
|
|
||||||
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
|
|
||||||
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
|
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
|
||||||
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
|
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
|
||||||
|
|
||||||
// image load
|
// image load
|
||||||
const orig_image = new Image();
|
|
||||||
window.addEventListener("resize", () => {
|
|
||||||
// repositioning
|
|
||||||
imgCanvas.width = window.innerWidth - 250;
|
|
||||||
imgCanvas.height = window.innerHeight - 200;
|
|
||||||
|
|
||||||
// redraw image
|
|
||||||
let drawWidth = orig_image.width;
|
|
||||||
let drawHeight = orig_image.height;
|
|
||||||
if (orig_image.width > imgCanvas.width) {
|
|
||||||
drawWidth = imgCanvas.width;
|
|
||||||
drawHeight = (drawWidth / orig_image.width) * orig_image.height;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (drawHeight > imgCanvas.height) {
|
|
||||||
drawHeight = imgCanvas.height;
|
|
||||||
drawWidth = (drawHeight / orig_image.height) * orig_image.width;
|
|
||||||
}
|
|
||||||
|
|
||||||
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
|
|
||||||
|
|
||||||
// update mask
|
|
||||||
maskCanvas.width = drawWidth;
|
|
||||||
maskCanvas.height = drawHeight;
|
|
||||||
maskCanvas.style.top = imgCanvas.offsetTop + "px";
|
|
||||||
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
|
|
||||||
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
|
|
||||||
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
|
|
||||||
});
|
|
||||||
|
|
||||||
const filepath = ComfyApp.clipspace.images;
|
const filepath = ComfyApp.clipspace.images;
|
||||||
|
|
||||||
const touched_image = new Image();
|
|
||||||
|
|
||||||
touched_image.onload = function() {
|
|
||||||
backupCanvas.width = touched_image.width;
|
|
||||||
backupCanvas.height = touched_image.height;
|
|
||||||
|
|
||||||
prepareRGB(touched_image, backupCanvas, backupCtx);
|
|
||||||
};
|
|
||||||
|
|
||||||
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
|
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
|
||||||
alpha_url.searchParams.delete('channel');
|
alpha_url.searchParams.delete('channel');
|
||||||
alpha_url.searchParams.delete('preview');
|
alpha_url.searchParams.delete('preview');
|
||||||
alpha_url.searchParams.set('channel', 'a');
|
alpha_url.searchParams.set('channel', 'a');
|
||||||
touched_image.src = alpha_url;
|
let mask_image = await loadImage(alpha_url);
|
||||||
|
|
||||||
// original image load
|
// original image load
|
||||||
orig_image.onload = function() {
|
|
||||||
window.dispatchEvent(new Event('resize'));
|
|
||||||
};
|
|
||||||
|
|
||||||
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
|
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
|
||||||
rgb_url.searchParams.delete('channel');
|
rgb_url.searchParams.delete('channel');
|
||||||
rgb_url.searchParams.set('channel', 'rgb');
|
rgb_url.searchParams.set('channel', 'rgb');
|
||||||
orig_image.src = rgb_url;
|
this.image = new Image();
|
||||||
this.image = orig_image;
|
this.image.onload = function() {
|
||||||
|
maskCanvas.width = self.image.width;
|
||||||
|
maskCanvas.height = self.image.height;
|
||||||
|
|
||||||
|
self.invalidateCanvas(self.image, mask_image);
|
||||||
|
self.initializeCanvasPanZoom();
|
||||||
|
};
|
||||||
|
this.image.src = rgb_url;
|
||||||
}
|
}
|
||||||
|
|
||||||
setEventHandler(maskCanvas) {
|
initializeCanvasPanZoom() {
|
||||||
maskCanvas.addEventListener("contextmenu", (event) => {
|
// set initialize
|
||||||
event.preventDefault();
|
let drawWidth = this.image.width;
|
||||||
});
|
let drawHeight = this.image.height;
|
||||||
|
|
||||||
|
let width = this.element.clientWidth;
|
||||||
|
let height = this.element.clientHeight;
|
||||||
|
|
||||||
|
if (this.image.width > width) {
|
||||||
|
drawWidth = width;
|
||||||
|
drawHeight = (drawWidth / this.image.width) * this.image.height;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (drawHeight > height) {
|
||||||
|
drawHeight = height;
|
||||||
|
drawWidth = (drawHeight / this.image.height) * this.image.width;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.zoom_ratio = drawWidth/this.image.width;
|
||||||
|
|
||||||
|
const canvasX = (width - drawWidth) / 2;
|
||||||
|
const canvasY = (height - drawHeight) / 2;
|
||||||
|
this.pan_x = canvasX;
|
||||||
|
this.pan_y = canvasY;
|
||||||
|
|
||||||
|
this.invalidatePanZoom();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
invalidatePanZoom() {
|
||||||
|
let raw_width = this.image.width * this.zoom_ratio;
|
||||||
|
let raw_height = this.image.height * this.zoom_ratio;
|
||||||
|
|
||||||
|
if(this.pan_x + raw_width < 10) {
|
||||||
|
this.pan_x = 10 - raw_width;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(this.pan_y + raw_height < 10) {
|
||||||
|
this.pan_y = 10 - raw_height;
|
||||||
|
}
|
||||||
|
|
||||||
|
let width = `${raw_width}px`;
|
||||||
|
let height = `${raw_height}px`;
|
||||||
|
|
||||||
|
let left = `${this.pan_x}px`;
|
||||||
|
let top = `${this.pan_y}px`;
|
||||||
|
|
||||||
|
this.maskCanvas.style.width = width;
|
||||||
|
this.maskCanvas.style.height = height;
|
||||||
|
this.maskCanvas.style.left = left;
|
||||||
|
this.maskCanvas.style.top = top;
|
||||||
|
|
||||||
|
this.imgCanvas.style.width = width;
|
||||||
|
this.imgCanvas.style.height = height;
|
||||||
|
this.imgCanvas.style.left = left;
|
||||||
|
this.imgCanvas.style.top = top;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
setEventHandler(maskCanvas) {
|
||||||
const self = this;
|
const self = this;
|
||||||
maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
|
|
||||||
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
|
if(!this.handler_registered) {
|
||||||
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
|
maskCanvas.addEventListener("contextmenu", (event) => {
|
||||||
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
|
event.preventDefault();
|
||||||
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
|
});
|
||||||
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
|
|
||||||
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
|
this.element.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
|
||||||
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
|
this.element.addEventListener('pointermove', (event) => this.pointMoveEvent(self,event));
|
||||||
|
this.element.addEventListener('touchmove', (event) => this.pointMoveEvent(self,event));
|
||||||
|
|
||||||
|
this.element.addEventListener('dragstart', (event) => {
|
||||||
|
if(event.ctrlKey) {
|
||||||
|
event.preventDefault();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
|
||||||
|
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
|
||||||
|
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
|
||||||
|
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
|
||||||
|
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
|
||||||
|
|
||||||
|
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
|
||||||
|
|
||||||
|
this.handler_registered = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
brush_size = 10;
|
brush_size = 10;
|
||||||
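The rewritten setEventHandler above guards all listener registration behind a handler_registered flag, since show() can run every time the editor opens and would otherwise stack duplicate listeners. A minimal sketch of the same idiom, with illustrative names (EditorSketch and onPointerDown are not from the diff), assuming only standard DOM APIs:

    class EditorSketch {
      setEventHandler(canvas) {
        // Attach listeners only once; later calls are no-ops.
        if (!this.handler_registered) {
          canvas.addEventListener("pointerdown", (event) => this.onPointerDown(event));
          this.handler_registered = true;
        }
      }
      onPointerDown(event) {
        // drawing / panning dispatch would go here
      }
    }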
@@ -378,8 +449,10 @@ class MaskEditorDialog extends ComfyDialog {
 const self = MaskEditorDialog.instance;
 if (event.key === ']') {
 self.brush_size = Math.min(self.brush_size+2, 100);
+self.brush_slider_input.value = self.brush_size;
 } else if (event.key === '[') {
 self.brush_size = Math.max(self.brush_size-2, 1);
+self.brush_slider_input.value = self.brush_size;
 } else if(event.key === 'Enter') {
 self.save();
 }
@@ -389,6 +462,10 @@ class MaskEditorDialog extends ComfyDialog {

 static handlePointerUp(event) {
 event.preventDefault();

+this.mousedown_x = null;
+this.mousedown_y = null;

 MaskEditorDialog.instance.drawing_mode = false;
 }

@@ -398,24 +475,70 @@ class MaskEditorDialog extends ComfyDialog {
 var centerX = self.cursorX;
 var centerY = self.cursorY;

-brush.style.width = self.brush_size * 2 + "px";
-brush.style.height = self.brush_size * 2 + "px";
-brush.style.left = (centerX - self.brush_size) + "px";
-brush.style.top = (centerY - self.brush_size) + "px";
+brush.style.width = self.brush_size * 2 * this.zoom_ratio + "px";
+brush.style.height = self.brush_size * 2 * this.zoom_ratio + "px";
+brush.style.left = (centerX - self.brush_size * this.zoom_ratio) + "px";
+brush.style.top = (centerY - self.brush_size * this.zoom_ratio) + "px";
 }

 handleWheelEvent(self, event) {
-if(event.deltaY < 0)
-self.brush_size = Math.min(self.brush_size+2, 100);
-else
-self.brush_size = Math.max(self.brush_size-2, 1);

-self.brush_slider_input.value = self.brush_size;
+event.preventDefault();

+if(event.ctrlKey) {
+// zoom canvas
+if(event.deltaY < 0) {
+this.zoom_ratio = Math.min(10.0, this.zoom_ratio+0.2);
+}
+else {
+this.zoom_ratio = Math.max(0.2, this.zoom_ratio-0.2);
+}

+this.invalidatePanZoom();
+}
+else {
+// adjust brush size
+if(event.deltaY < 0)
+this.brush_size = Math.min(this.brush_size+2, 100);
+else
+this.brush_size = Math.max(this.brush_size-2, 1);

+this.brush_slider_input.value = this.brush_size;

+this.updateBrushPreview(this);
+}
+}

+pointMoveEvent(self, event) {
+this.cursorX = event.pageX;
+this.cursorY = event.pageY;

 self.updateBrushPreview(self);

+if(event.ctrlKey) {
+event.preventDefault();
+self.pan_move(self, event);
+}
+}

+pan_move(self, event) {
+if(event.buttons == 1) {
+if(this.mousedown_x) {
+let deltaX = this.mousedown_x - event.clientX;
+let deltaY = this.mousedown_y - event.clientY;

+self.pan_x = this.mousedown_pan_x - deltaX;
+self.pan_y = this.mousedown_pan_y - deltaY;

+self.invalidatePanZoom();
+}
+}
 }

 draw_move(self, event) {
+if(event.ctrlKey) {
+return;
+}

 event.preventDefault();

 this.cursorX = event.pageX;
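Because the canvases now keep their intrinsic pixel size while being scaled with CSS width/height, the brush preview above multiplies by zoom_ratio and the drawing hunks below divide pointer offsets by it. A worked sketch of the conversion, with illustrative numbers (none of these values appear in the diff):

    // The canvas is 512px wide internally but displayed at 1024px,
    // so one canvas pixel covers two CSS pixels.
    const zoom_ratio = 1024 / 512;           // 2.0
    const offsetX = 300;                     // pointer offset in CSS pixels
    const canvasX = offsetX / zoom_ratio;    // 150: the pixel column to paint
    const brushCssRadius = 10 * zoom_ratio;  // a 10px brush drawn 20px wide on screen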
@@ -439,6 +562,9 @@ class MaskEditorDialog extends ComfyDialog {
 y = event.targetTouches[0].clientY - maskRect.top;
 }

+x /= self.zoom_ratio;
+y /= self.zoom_ratio;

 var brush_size = this.brush_size;
 if(event instanceof PointerEvent && event.pointerType == 'pen') {
 brush_size *= event.pressure;
@@ -489,8 +615,8 @@ class MaskEditorDialog extends ComfyDialog {
 }
 else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
 const maskRect = self.maskCanvas.getBoundingClientRect();
-const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
-const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
+const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
+const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;

 var brush_size = this.brush_size;
 if(event instanceof PointerEvent && event.pointerType == 'pen') {
@@ -540,6 +666,17 @@ class MaskEditorDialog extends ComfyDialog {
 }

 handlePointerDown(self, event) {
+if(event.ctrlKey) {
+if (event.buttons == 1) {
+this.mousedown_x = event.clientX;
+this.mousedown_y = event.clientY;

+this.mousedown_pan_x = this.pan_x;
+this.mousedown_pan_y = this.pan_y;
+}
+return;
+}

 var brush_size = this.brush_size;
 if(event instanceof PointerEvent && event.pointerType == 'pen') {
 brush_size *= event.pressure;
@@ -551,8 +688,8 @@ class MaskEditorDialog extends ComfyDialog {

 event.preventDefault();
 const maskRect = self.maskCanvas.getBoundingClientRect();
-const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
-const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
+const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
+const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;

 self.maskCtx.beginPath();
 if (event.button == 0) {
@@ -570,15 +707,18 @@ class MaskEditorDialog extends ComfyDialog {
 }

 async save() {
-const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
+const backupCanvas = document.createElement('canvas');
+const backupCtx = backupCanvas.getContext('2d', {willReadFrequently:true});
+backupCanvas.width = this.image.width;
+backupCanvas.height = this.image.height;

-backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
+backupCtx.clearRect(0,0, backupCanvas.width, backupCanvas.height);
 backupCtx.drawImage(this.maskCanvas,
 0, 0, this.maskCanvas.width, this.maskCanvas.height,
-0, 0, this.backupCanvas.width, this.backupCanvas.height);
+0, 0, backupCanvas.width, backupCanvas.height);

 // paste mask data into alpha channel
-const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
+const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);

 // refine mask image
 for (let i = 0; i < backupData.data.length; i += 4) {
@@ -615,7 +755,7 @@ class MaskEditorDialog extends ComfyDialog {
 ComfyApp.clipspace.widgets[index].value = item;
 }

-const dataURL = this.backupCanvas.toDataURL();
+const dataURL = backupCanvas.toDataURL();
 const blob = dataURLToBlob(dataURL);

 let original_url = new URL(this.image.src);
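The new setImages awaits a loadImage(alpha_url) helper instead of wiring Image.onload callbacks by hand. That helper is defined elsewhere in the file and is not shown in this diff; a plausible minimal implementation of such a promise wrapper (an assumption, not the commit's actual code) would be:

    // Hypothetical sketch of the loadImage helper the diff relies on:
    // resolve with the Image only once it has finished loading.
    function loadImage(imagePath) {
      return new Promise((resolve, reject) => {
        const image = new Image();
        image.onload = () => resolve(image);
        image.onerror = reject;
        image.src = imagePath;
      });
    }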
@@ -1,10 +1,11 @@
 import { app } from "../../scripts/app.js";
+import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs.js";

 // Node that allows you to redirect connections for cleaner graphs

 app.registerExtension({
 name: "Comfy.RerouteNode",
-registerCustomNodes() {
+registerCustomNodes(app) {
 class RerouteNode {
 constructor() {
 if (!this.properties) {
@@ -16,6 +17,12 @@ app.registerExtension({
 this.addInput("", "*");
 this.addOutput(this.properties.showOutputText ? "*" : "", "*");

+this.onAfterGraphConfigured = function () {
+requestAnimationFrame(() => {
+this.onConnectionsChange(LiteGraph.INPUT, null, true, null);
+});
+};

 this.onConnectionsChange = function (type, index, connected, link_info) {
 this.applyOrientation();

@@ -47,6 +54,7 @@ app.registerExtension({
 const linkId = currentNode.inputs[0].link;
 if (linkId !== null) {
 const link = app.graph.links[linkId];
+if (!link) return;
 const node = app.graph.getNodeById(link.origin_id);
 const type = node.constructor.type;
 if (type === "Reroute") {
@@ -54,8 +62,7 @@ app.registerExtension({
 // We've found a circle
 currentNode.disconnectInput(link.target_slot);
 currentNode = null;
-}
-else {
+} else {
 // Move the previous node
 currentNode = node;
 }
@@ -94,8 +101,11 @@ app.registerExtension({
 updateNodes.push(node);
 } else {
 // We've found an output
-const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
-if (inputType && nodeOutType !== inputType) {
+const nodeOutType =
+node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type
+? node.inputs[link.target_slot].type
+: null;
+if (inputType && inputType !== "*" && nodeOutType !== inputType) {
 // The output doesnt match our input so disconnect it
 node.disconnectInput(link.target_slot);
 } else {
@@ -111,6 +121,9 @@ app.registerExtension({
 const displayType = inputType || outputType || "*";
 const color = LGraphCanvas.link_type_colors[displayType];

+let widgetConfig;
+let targetWidget;
+let widgetType;
 // Update the types of each node
 for (const node of updateNodes) {
 // If we dont have an input type we are always wildcard but we'll show the output type
@@ -125,10 +138,38 @@ app.registerExtension({
 const link = app.graph.links[l];
 if (link) {
 link.color = color;

+if (app.configuringGraph) continue;
+const targetNode = app.graph.getNodeById(link.target_id);
+const targetInput = targetNode.inputs?.[link.target_slot];
+if (targetInput?.widget) {
+const config = getWidgetConfig(targetInput);
+if (!widgetConfig) {
+widgetConfig = config[1] ?? {};
+widgetType = config[0];
+}
+if (!targetWidget) {
+targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name);
+}

+const merged = mergeIfValid(targetInput, [config[0], widgetConfig]);
+if (merged.customConfig) {
+widgetConfig = merged.customConfig;
+}
+}
 }
 }
 }

+for (const node of updateNodes) {
+if (widgetConfig && outputType) {
+node.inputs[0].widget = { name: "value" };
+setWidgetConfig(node.inputs[0], [widgetType ?? displayType, widgetConfig], targetWidget);
+} else {
+setWidgetConfig(node.inputs[0], null);
+}
+}

 if (inputNode) {
 const link = app.graph.links[inputNode.inputs[0].link];
 if (link) {
@@ -173,8 +214,8 @@ app.registerExtension({
 },
 {
 // naming is inverted with respect to LiteGraphNode.horizontal
 // LiteGraphNode.horizontal == true means that
 // each slot in the inputs and outputs are layed out horizontally,
 // which is the opposite of the visual orientation of the inputs and outputs as a node
 content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"),
 callback: () => {
@@ -187,7 +228,7 @@ app.registerExtension({
 applyOrientation() {
 this.horizontal = this.properties.horizontal;
 if (this.horizontal) {
 // we correct the input position, because LiteGraphNode.horizontal
 // doesn't account for title presence
 // which reroute nodes don't have
 this.inputs[0].pos = [this.size[0] / 2, 0];
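The Reroute changes above lean on the new getWidgetConfig/setWidgetConfig exports from widgetInputs.js (shown later in this diff) to carry a widget's config through a chain of reroutes. A hedged sketch of the call pattern; the rerouteNode object and the config contents are illustrative, only the two functions and the [type, options] shape come from the diff:

    import { getWidgetConfig, setWidgetConfig } from "./widgetInputs.js";

    // Propagate a STRING widget config through a reroute input slot.
    const rerouteNode = { inputs: [{}] };                 // stand-in for a real node
    const config = ["STRING", { multiline: true }];       // [type, options]
    rerouteNode.inputs[0].widget = { name: "value" };     // placeholder widget slot
    setWidgetConfig(rerouteNode.inputs[0], config);       // stash it on the slot
    console.log(getWidgetConfig(rerouteNode.inputs[0]));  // -> ["STRING", {...}]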
@@ -1,5 +1,5 @@
 import { app } from "../../scripts/app.js";
+import { applyTextReplacements } from "../../scripts/utils.js";
 // Use widget values and dates in output filenames

 app.registerExtension({
@@ -7,84 +7,19 @@ app.registerExtension({
 async beforeRegisterNodeDef(nodeType, nodeData, app) {
 if (nodeData.name === "SaveImage") {
 const onNodeCreated = nodeType.prototype.onNodeCreated;
-// Simple date formatter
-const parts = {
-d: (d) => d.getDate(),
-M: (d) => d.getMonth() + 1,
-h: (d) => d.getHours(),
-m: (d) => d.getMinutes(),
-s: (d) => d.getSeconds(),
-};
-const format =
-Object.keys(parts)
-.map((k) => k + k + "?")
-.join("|") + "|yyy?y?";

-function formatDate(text, date) {
-return text.replace(new RegExp(format, "g"), function (text) {
-if (text === "yy") return (date.getFullYear() + "").substring(2);
-if (text === "yyyy") return date.getFullYear();
-if (text[0] in parts) {
-const p = parts[text[0]](date);
-return (p + "").padStart(text.length, "0");
-}
-return text;
-});
-}

-// When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R
+// When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R
 nodeType.prototype.onNodeCreated = function () {
 const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;

 const widget = this.widgets.find((w) => w.name === "filename_prefix");
 widget.serializeValue = () => {
-return widget.value.replace(/%([^%]+)%/g, function (match, text) {
-const split = text.split(".");
-if (split.length !== 2) {
-// Special handling for dates
-if (split[0].startsWith("date:")) {
-return formatDate(split[0].substring(5), new Date());
-}

-if (text !== "width" && text !== "height") {
-// Dont warn on standard replacements
-console.warn("Invalid replacement pattern", text);
-}
-return match;
-}

-// Find node with matching S&R property name
-let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]);
-// If we cant, see if there is a node with that title
-if (!nodes.length) {
-nodes = app.graph._nodes.filter((n) => n.title === split[0]);
-}
-if (!nodes.length) {
-console.warn("Unable to find node", split[0]);
-return match;
-}

-if (nodes.length > 1) {
-console.warn("Multiple nodes matched", split[0], "using first match");
-}

-const node = nodes[0];

-const widget = node.widgets?.find((w) => w.name === split[1]);
-if (!widget) {
-console.warn("Unable to find widget", split[1], "on node", split[0], node);
-return match;
-}

-return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
-});
+return applyTextReplacements(app, widget.value);
 };

 return r;
 };
 } else {
 // When any other node is created add a property to alias the node
 const onNodeCreated = nodeType.prototype.onNodeCreated;
 nodeType.prototype.onNodeCreated = function () {
 const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
@@ -71,24 +71,21 @@ function graphEqual(a, b, root = true) {
 }

 const undoRedo = async (e) => {
+const updateState = async (source, target) => {
+const prevState = source.pop();
+if (prevState) {
+target.push(activeState);
+isOurLoad = true;
+await app.loadGraphData(prevState, false);
+activeState = prevState;
+}
+}
 if (e.ctrlKey || e.metaKey) {
 if (e.key === "y") {
-const prevState = redo.pop();
-if (prevState) {
-undo.push(activeState);
-isOurLoad = true;
-await app.loadGraphData(prevState);
-activeState = prevState;
-}
+updateState(redo, undo);
 return true;
 } else if (e.key === "z") {
-const prevState = undo.pop();
-if (prevState) {
-redo.push(activeState);
-isOurLoad = true;
-await app.loadGraphData(prevState);
-activeState = prevState;
-}
+updateState(undo, redo);
 return true;
 }
 }
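The undo and redo branches were mirror images of each other, so the diff folds them into one updateState(source, target) helper; note it also passes false for the new clean parameter of loadGraphData so node state such as images survives the reload. The stack transfer pattern in isolation (stacks and state values here are illustrative):

    // Sketch of the symmetric stack transfer used by updateState.
    let activeState = "C";
    const undo = ["A", "B"];
    const redo = [];

    function updateState(source, target) {
      const prevState = source.pop();
      if (prevState) {
        target.push(activeState);   // current state becomes re-doable/un-doable
        activeState = prevState;    // restored state becomes current
      }
    }

    updateState(undo, redo); // undo: ["A"], redo: ["C"], activeState: "B"
    updateState(redo, undo); // back to the starting point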
@@ -1,10 +1,16 @@
 import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
 import { app } from "../../scripts/app.js";
+import { applyTextReplacements } from "../../scripts/utils.js";

 const CONVERTED_TYPE = "converted-widget";
 const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
 const CONFIG = Symbol();
 const GET_CONFIG = Symbol();
+const TARGET = Symbol(); // Used for reroutes to specify the real target widget

+export function getWidgetConfig(slot) {
+return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]();
+}

 function getConfig(widgetName) {
 const { nodeData } = this.constructor;
@@ -100,7 +106,6 @@ function getWidgetType(config) {
 return { type };
 }

-
 function isValidCombo(combo, obj) {
 // New input isnt a combo
 if (!(obj instanceof Array)) {
@@ -121,6 +126,31 @@ function isValidCombo(combo, obj) {
 return true;
 }

+export function setWidgetConfig(slot, config, target) {
+if (!slot.widget) return;
+if (config) {
+slot.widget[GET_CONFIG] = () => config;
+slot.widget[TARGET] = target;
+} else {
+delete slot.widget;
+}

+if (slot.link) {
+const link = app.graph.links[slot.link];
+if (link) {
+const originNode = app.graph.getNodeById(link.origin_id);
+if (originNode.type === "PrimitiveNode") {
+if (config) {
+originNode.recreateWidget();
+} else if(!app.configuringGraph) {
+originNode.disconnectOutput(0);
+originNode.onLastDisconnect();
+}
+}
+}
+}
+}

 export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
 if (!config1) {
 config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
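The CONFIG, GET_CONFIG, and new TARGET keys are Symbols, which keeps this metadata from colliding with ordinary widget properties or leaking into serialization. A small sketch of why that works (the slot widget and config values are illustrative):

    // Symbol keys let the extension attach metadata to a widget object
    // without colliding with, or being serialized alongside, its
    // ordinary string-keyed properties.
    const GET_CONFIG = Symbol();
    const TARGET = Symbol();

    const slotWidget = { name: "value" };
    slotWidget[GET_CONFIG] = () => ["INT", { min: 0, max: 10 }];
    slotWidget[TARGET] = { name: "seed", value: 5 }; // the real widget to drive

    console.log(JSON.stringify(slotWidget)); // {"name":"value"} - symbol keys are skipped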
@@ -150,7 +180,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {

 const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
 for (const k of keys.values()) {
-if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
+if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline") {
 let v1 = config1[1][k];
 let v2 = config2[1]?.[k];

@@ -405,11 +435,16 @@ app.registerExtension({
 };
 },
 registerCustomNodes() {
+const replacePropertyName = "Run widget replace on values";
 class PrimitiveNode {
 constructor() {
 this.addOutput("connect to widget input", "*");
 this.serialize_widgets = true;
 this.isVirtualNode = true;

+if (!this.properties || !(replacePropertyName in this.properties)) {
+this.addProperty(replacePropertyName, false, "boolean");
+}
 }

 applyToGraph(extraLinks = []) {
@@ -430,18 +465,29 @@ app.registerExtension({
 }

 let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks];
+let v = this.widgets?.[0].value;
+if(v && this.properties[replacePropertyName]) {
+v = applyTextReplacements(app, v);
+}

 // For each output link copy our value over the original widget value
 for (const linkInfo of links) {
 const node = this.graph.getNodeById(linkInfo.target_id);
 const input = node.inputs[linkInfo.target_slot];
-const widgetName = input.widget.name;
-if (widgetName) {
-const widget = node.widgets.find((w) => w.name === widgetName);
-if (widget) {
-widget.value = this.widgets[0].value;
-if (widget.callback) {
-widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {});
+let widget;
+if (input.widget[TARGET]) {
+widget = input.widget[TARGET];
+} else {
+const widgetName = input.widget.name;
+if (widgetName) {
+widget = node.widgets.find((w) => w.name === widgetName);
 }
+}

+if (widget) {
+widget.value = v;
+if (widget.callback) {
+widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {});
 }
 }
 }
@@ -494,14 +540,13 @@ app.registerExtension({
 this.#mergeWidgetConfig();

 if (!links?.length) {
-this.#onLastDisconnect();
+this.onLastDisconnect();
 }
 }
 }

 onConnectOutput(slot, type, input, target_node, target_slot) {
 // Fires before the link is made allowing us to reject it if it isn't valid

 // No widget, we cant connect
 if (!input.widget) {
 if (!(input.type in ComfyWidgets)) return false;
@@ -519,6 +564,10 @@ app.registerExtension({

 #onFirstConnection(recreating) {
 // First connection can fire before the graph is ready on initial load so random things can be missing
+if (!this.outputs[0].links) {
+this.onLastDisconnect();
+return;
+}
 const linkId = this.outputs[0].links[0];
 const link = this.graph.links[linkId];
 if (!link) return;
@@ -546,10 +595,10 @@ app.registerExtension({
 this.outputs[0].name = type;
 this.outputs[0].widget = widget;

-this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating);
+this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating, widget[TARGET]);
 }

-#createWidget(inputData, node, widgetName, recreating) {
+#createWidget(inputData, node, widgetName, recreating, targetWidget) {
 let type = inputData[0];

 if (type instanceof Array) {
@@ -563,7 +612,9 @@ app.registerExtension({
 widget = this.addWidget(type, "value", null, () => {}, {});
 }

-if (node?.widgets && widget) {
+if (targetWidget) {
+widget.value = targetWidget.value;
+} else if (node?.widgets && widget) {
 const theirWidget = node.widgets.find((w) => w.name === widgetName);
 if (theirWidget) {
 widget.value = theirWidget.value;
@@ -577,11 +628,19 @@ app.registerExtension({
 }
 addValueControlWidgets(this, widget, control_value, undefined, inputData);
 let filter = this.widgets_values?.[2];
-if(filter && this.widgets.length === 3) {
+if (filter && this.widgets.length === 3) {
 this.widgets[2].value = filter;
 }
 }

+// Restore any saved control values
+const controlValues = this.controlValues;
+if(this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) {
+for(let i = 0; i < controlValues.length; i++) {
+this.widgets[i + 1].value = controlValues[i];
+}
+}

 // When our value changes, update other widgets to reflect our changes
 // e.g. so LoadImage shows correct image
 const callback = widget.callback;
@@ -610,12 +669,14 @@ app.registerExtension({
 }
 }

-#recreateWidget() {
-const values = this.widgets.map((w) => w.value);
+recreateWidget() {
+const values = this.widgets?.map((w) => w.value);
 this.#removeWidgets();
 this.#onFirstConnection(true);
-for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
-return this.widgets[0];
+if (values?.length) {
+for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
+}
+return this.widgets?.[0];
 }

 #mergeWidgetConfig() {
@@ -631,7 +692,7 @@ app.registerExtension({
 if (links?.length < 2 && hasConfig) {
 // Copy the widget options from the source
 if (links.length) {
-this.#recreateWidget();
+this.recreateWidget();
 }

 return;
@@ -657,7 +718,7 @@ app.registerExtension({
 // Only allow connections where the configs match
 const output = this.outputs[0];
 const config2 = input.widget[GET_CONFIG]();
-return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget);
+return !!mergeIfValid.call(this, output, config2, forceUpdate, this.recreateWidget);
 }

 #removeWidgets() {
@@ -668,11 +729,20 @@ app.registerExtension({
 w.onRemove();
 }
 }

+// Temporarily store the current values in case the node is being recreated
+// e.g. by group node conversion
+this.controlValues = [];
+this.lastType = this.widgets[0]?.type;
+for(let i = 1; i < this.widgets.length; i++) {
+this.controlValues.push(this.widgets[i].value);
+}
+setTimeout(() => { delete this.lastType; delete this.controlValues }, 15);
 this.widgets.length = 0;
 }
 }

-#onLastDisconnect() {
+onLastDisconnect() {
 // We cant remove + re-add the output here as if you drag a link over the same link
 // it removes, then re-adds, causing it to break
 this.outputs[0].type = "*";
@@ -48,7 +48,7 @@
 EVENT_LINK_COLOR: "#A86",
 CONNECTING_LINK_COLOR: "#AFA",

-MAX_NUMBER_OF_NODES: 1000, //avoid infinite loops
+MAX_NUMBER_OF_NODES: 10000, //avoid infinite loops
 DEFAULT_POSITION: [100, 100], //default node position
 VALID_SHAPES: ["default", "box", "round", "card"], //,"circle"

@@ -3788,16 +3788,42 @@

 /**
 * returns the bounding of the object, used for rendering purposes
-* bounding is: [topleft_cornerx, topleft_cornery, width, height]
 * @method getBounding
-* @return {Float32Array[4]} the total size
+* @param out {Float32Array[4]?} [optional] a place to store the output, to free garbage
+* @param compute_outer {boolean?} [optional] set to true to include the shadow and connection points in the bounding calculation
+* @return {Float32Array[4]} the bounding box in format of [topleft_cornerx, topleft_cornery, width, height]
 */
-LGraphNode.prototype.getBounding = function(out) {
+LGraphNode.prototype.getBounding = function(out, compute_outer) {
 out = out || new Float32Array(4);
-out[0] = this.pos[0] - 4;
-out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
-out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
-out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
+const nodePos = this.pos;
+const isCollapsed = this.flags.collapsed;
+const nodeSize = this.size;

+let left_offset = 0;
+// 1 offset due to how nodes are rendered
+let right_offset = 1 ;
+let top_offset = 0;
+let bottom_offset = 0;

+if (compute_outer) {
+// 4 offset for collapsed node connection points
+left_offset = 4;
+// 6 offset for right shadow and collapsed node connection points
+right_offset = 6 + left_offset;
+// 4 offset for collapsed nodes top connection points
+top_offset = 4;
+// 5 offset for bottom shadow and collapsed node connection points
+bottom_offset = 5 + top_offset;
+}

+out[0] = nodePos[0] - left_offset;
+out[1] = nodePos[1] - LiteGraph.NODE_TITLE_HEIGHT - top_offset;
+out[2] = isCollapsed ?
+(this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) + right_offset :
+nodeSize[0] + right_offset;
+out[3] = isCollapsed ?
+LiteGraph.NODE_TITLE_HEIGHT + bottom_offset :
+nodeSize[1] + LiteGraph.NODE_TITLE_HEIGHT + bottom_offset;

 if (this.onBounding) {
 this.onBounding(out);
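The compute_outer flag pads the box so shadows and collapsed-node connection dots count as part of the node when culling, which is why the render loop in the next hunk passes true: without it, a node whose body is just off-screen could have its still-visible shadow clipped. A sketch of the call, assuming litegraph.core.js is loaded (node is an LGraphNode, lcanvas an LGraphCanvas, overlapBounding a litegraph helper; the wrapper function is illustrative):

    // Reuse one Float32Array to avoid per-frame allocations, and include
    // the outer decorations when testing visibility.
    const temp = new Float32Array(4);
    function isNodeVisible(lcanvas, node) {
      const bounds = node.getBounding(temp, true); // [x, y, width, height]
      return overlapBounding(lcanvas.visible_area, bounds);
    }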
@@ -7674,7 +7700,7 @@ LGraphNode.prototype.executeAction = function(action)
 continue;
 }

-if (!overlapBounding(this.visible_area, n.getBounding(temp))) {
+if (!overlapBounding(this.visible_area, n.getBounding(temp, true))) {
 continue;
 } //out of the visible area

@@ -11336,6 +11362,7 @@ LGraphNode.prototype.executeAction = function(action)
 name_element.innerText = title;
 var value_element = dialog.querySelector(".value");
 value_element.value = value;
+value_element.select();

 var input = value_element;
 input.addEventListener("keydown", function(e) {
@@ -1559,9 +1559,12 @@ export class ComfyApp {
 /**
 * Populates the graph with the specified workflow data
 * @param {*} graphData A serialized graph object
+* @param { boolean } clean If the graph state, e.g. images, should be cleared
 */
-async loadGraphData(graphData) {
-this.clean();
+async loadGraphData(graphData, clean = true) {
+if (clean !== false) {
+this.clean();
+}

 let reset_invalid_values = false;
 if (!graphData) {
@@ -1771,15 +1774,26 @@ export class ComfyApp {
 if (parent?.updateLink) {
 link = parent.updateLink(link);
 }
-inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
+if (link) {
+inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
+}
 }
 }
 }

-output[String(node.id)] = {
+let node_data = {
 inputs,
 class_type: node.comfyClass,
 };

+if (this.ui.settings.getSettingValue("Comfy.DevMode")) {
+// Ignored by the backend.
+node_data["_meta"] = {
+title: node.title,
+}
+}

+output[String(node.id)] = node_data;
 }
 }

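With the Comfy.DevMode setting enabled, each entry in the exported prompt gains a _meta block that the backend ignores. Roughly what one serialized entry would look like (the node id, inputs, and title here are illustrative, only the shape comes from the diff):

    // Sketch of a prompt entry as serialized with Comfy.DevMode on.
    const output = {};
    const node_data = {
      inputs: { seed: 5, model: ["4", 0] },
      class_type: "KSampler",
      _meta: { title: "My fancy sampler" }, // ignored by the backend
    };
    output["3"] = node_data;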
@@ -2006,12 +2020,8 @@ export class ComfyApp {
 async refreshComboInNodes() {
 const defs = await api.getNodeDefs();

-for(const nodeId in LiteGraph.registered_node_types) {
-const node = LiteGraph.registered_node_types[nodeId];
-const nodeDef = defs[nodeId];
-if(!nodeDef) continue;

-node.nodeData = nodeDef;
+for (const nodeId in defs) {
+this.registerNodeDef(nodeId, defs[nodeId]);
 }

 for(let nodeNum in this.graph._nodes) {
@@ -177,6 +177,7 @@ LGraphCanvas.prototype.computeVisibleNodes = function () {
 for (const w of node.widgets) {
 if (w.element) {
 w.element.hidden = hidden;
+w.element.style.display = hidden ? "none" : undefined;
 if (hidden) {
 w.options.onHide?.(w);
 }
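Setting the hidden attribute alone only applies the user-agent default [hidden] { display: none }, which a stylesheet rule that assigns the widget element an explicit display can override; the hunk above therefore also forces the inline style, which wins regardless. A small sketch (the element selection and the CSS rule in the comment are illustrative):

    // A rule like `.widget-input { display: block; }` would defeat
    // the `hidden` attribute alone, so mirror it with an inline style:
    const el = document.querySelector("textarea"); // stand-in for a widget element
    let hidden = true;
    el.hidden = hidden;
    el.style.display = hidden ? "none" : undefined; // mirrors the hunk above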
web/scripts/utils.js (new file, 67 lines)
@@ -0,0 +1,67 @@
+// Simple date formatter
+const parts = {
+d: (d) => d.getDate(),
+M: (d) => d.getMonth() + 1,
+h: (d) => d.getHours(),
+m: (d) => d.getMinutes(),
+s: (d) => d.getSeconds(),
+};
+const format =
+Object.keys(parts)
+.map((k) => k + k + "?")
+.join("|") + "|yyy?y?";

+function formatDate(text, date) {
+return text.replace(new RegExp(format, "g"), function (text) {
+if (text === "yy") return (date.getFullYear() + "").substring(2);
+if (text === "yyyy") return date.getFullYear();
+if (text[0] in parts) {
+const p = parts[text[0]](date);
+return (p + "").padStart(text.length, "0");
+}
+return text;
+});
+}

+export function applyTextReplacements(app, value) {
+return value.replace(/%([^%]+)%/g, function (match, text) {
+const split = text.split(".");
+if (split.length !== 2) {
+// Special handling for dates
+if (split[0].startsWith("date:")) {
+return formatDate(split[0].substring(5), new Date());
+}

+if (text !== "width" && text !== "height") {
+// Dont warn on standard replacements
+console.warn("Invalid replacement pattern", text);
+}
+return match;
+}

+// Find node with matching S&R property name
+let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]);
+// If we cant, see if there is a node with that title
+if (!nodes.length) {
+nodes = app.graph._nodes.filter((n) => n.title === split[0]);
+}
+if (!nodes.length) {
+console.warn("Unable to find node", split[0]);
+return match;
+}

+if (nodes.length > 1) {
+console.warn("Multiple nodes matched", split[0], "using first match");
+}

+const node = nodes[0];

+const widget = node.widgets?.find((w) => w.name === split[1]);
+if (!widget) {
+console.warn("Unable to find widget", split[1], "on node", split[0], node);
+return match;
+}

+return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
+});
+}
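The extracted helper keeps the old SaveImage behavior: %date:FORMAT% tokens go through formatDate, %Node.widget% tokens are resolved against the graph, and path separators in widget values become underscores. Example usage, assuming app is the ComfyUI app instance (the prefix strings and node/widget names below are illustrative):

    import { applyTextReplacements } from "../../scripts/utils.js";

    // Date tokens: yields e.g. "ComfyUI_2024-01-05" on January 5th, 2024.
    applyTextReplacements(app, "ComfyUI_%date:yyyy-MM-dd%");

    // Widget lookup: value of the "text" widget on the node titled "MyPrompt",
    // with any / or \ characters replaced by _ so it stays a valid filename.
    applyTextReplacements(app, "%MyPrompt.text%_output");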