Merge upstream, fix 3.12 compatibility, fix nightlies issue, fix broken node

This commit is contained in:
doctorpangloss 2024-01-03 16:00:36 -08:00
commit 369aeb598f
75 changed files with 2798 additions and 1075 deletions

View File

@ -1,3 +0,0 @@
..\python_embeded\python.exe .\update.py ..\ComfyUI\
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
pause

View File

@ -1,2 +0,0 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention
pause

View File

@ -22,5 +22,5 @@ jobs:
run: |
npm ci
npm run test:generate
npm test
npm test -- --verbose
working-directory: ./tests-ui

View File

@ -2,6 +2,24 @@ name: "Windows Release Nightly pytorch"
on:
workflow_dispatch:
inputs:
cu:
description: 'cuda version'
required: true
type: string
default: "121"
python_minor:
description: 'python minor version'
required: true
type: string
default: "12"
python_patch:
description: 'python patch version'
required: true
type: string
default: "1"
# push:
# branches:
# - master
@ -20,21 +38,21 @@ jobs:
persist-credentials: false
- uses: actions/setup-python@v4
with:
python-version: '3.11.6'
python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
- shell: bash
run: |
cd ..
cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip
curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded
cd python_embeded
echo 'import site' >> ./python311._pth
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python311._pth
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
cd ..
git clone https://github.com/comfyanonymous/taesd
@ -49,9 +67,10 @@ jobs:
mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
echo "..\python_embeded\python.exe .\update.py ..\ComfyUI\\
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > ./update/update_comfyui_and_python_dependencies.bat
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch

View File

@ -1,9 +0,0 @@
{
"path-intellisense.mappings": {
"../": "${workspaceFolder}/web/extensions/core"
},
"[python]": {
"editor.defaultFormatter": "ms-python.autopep8"
},
"python.formatting.provider": "none"
}

0
__init__.py Normal file
View File

View File

@ -53,7 +53,7 @@ class ControlNet(nn.Module):
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
operations=ops,
operations=ops.disable_weight_init,
**kwargs,
):
super().__init__()
@ -141,24 +141,24 @@ class ControlNet(nn.Module):
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
self.input_hint_block = TimestepEmbedSequential(
operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 16, 16, 3, padding=1),
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 32, 32, 3, padding=1),
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 96, 96, 3, padding=1),
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
)
self._feature_size = model_channels
@ -206,7 +206,7 @@ class ControlNet(nn.Module):
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
@ -234,7 +234,7 @@ class ControlNet(nn.Module):
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
ds *= 2
self._feature_size += ch
@ -276,14 +276,14 @@ class ControlNet(nn.Module):
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch
def make_zero_conv(self, channels, operations=None):
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context)
@ -295,7 +295,7 @@ class ControlNet(nn.Module):
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
h = x
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context)

View File

@ -55,13 +55,19 @@ fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
@ -98,7 +104,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

188
comfy/clip_model.py Normal file
View File

@ -0,0 +1,188 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
self.heads = heads
self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None, optimized_attention=None):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
out = optimized_attention(q, k, v, self.heads, mask)
return self.out_proj(out)
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
}
class CLIPMLP(torch.nn.Module):
def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
super().__init__()
self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
self.activation = ACTIVATIONS[activation]
self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x
class CLIPLayer(torch.nn.Module):
def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
super().__init__()
self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
def forward(self, x, mask=None, optimized_attention=None):
x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
x += self.mlp(self.layer_norm2(x))
return x
class CLIPEncoder(torch.nn.Module):
def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
super().__init__()
self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
for i, l in enumerate(self.layers):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
super().__init__()
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens):
return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
if mask is not None:
mask += causal_mask
else:
mask = causal_mask
x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
x = self.final_layer_norm(x)
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.num_layers = config_dict["num_hidden_layers"]
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
self.dtype = dtype
def get_input_embeddings(self):
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, embeddings):
self.text_model.embeddings.token_embedding = embeddings
def forward(self, *args, **kwargs):
return self.text_model(*args, **kwargs)
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
super().__init__()
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
self.patch_embedding = operations.Conv2d(
in_channels=num_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
dtype=dtype,
device=device
)
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
class CLIPVision(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
embed_dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x = self.pre_layrnorm(x)
#TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2])
return (x[0], x[1], out)

View File

@ -1,36 +1,43 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
from .utils import load_torch_file, transformers_convert
import os
import torch
import contextlib
import json
from . import ops
from . import model_patcher
from . import model_management
from . import clip_model
class Output:
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
scale = (size / min(image.shape[1], image.shape[2]))
image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
scale = (size / min(image.shape[2], image.shape[3]))
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
class ClipVisionModel():
def __init__(self, json_config):
config = CLIPVisionConfig.from_json_file(json_config)
with open(json_config) as f:
config = json.load(f)
self.load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
self.dtype = torch.float32
if model_management.should_use_fp16(self.load_device, prioritize_performance=False):
self.dtype = torch.float16
with ops.use_comfy_ops(offload_device, self.dtype):
with modeling_utils.no_init_weights():
self.model = CLIPVisionModelWithProjection(config)
self.model.to(self.dtype)
self.dtype = model_management.text_encoder_dtype(self.load_device)
self.model = clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, ops.manual_cast)
self.model.eval()
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
@ -38,25 +45,13 @@ class ClipVisionModel():
def encode_image(self, image):
model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device))
if self.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
with precision_scope(model_management.get_autocast_device(self.load_device), torch.float32):
outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
for k in outputs:
t = outputs[k]
if t is not None:
if k == 'hidden_states':
outputs["penultimate_hidden_states"] = t[-2].cpu()
outputs["hidden_states"] = None
else:
outputs[k] = t.cpu()
pixel_values = clip_preprocess(image.to(self.load_device)).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(model_management.intermediate_device())
return outputs
def convert_to_transformers(sd, prefix):
@ -86,6 +81,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
if convert_keys:
sd = convert_to_transformers(sd, prefix)
if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
# todo: fix the importlib issue here
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")

View File

@ -13,6 +13,8 @@ import typing
from dataclasses import dataclass
from typing import Tuple
import sys
import gc
import inspect
import torch
@ -442,6 +444,8 @@ class PromptExecutor:
for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x])
self.server.last_node_id = None
if model_management.DISABLE_SMART_MEMORY:
model_management.unload_all_models()
@ -462,6 +466,14 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
errors = []
valid = True
# todo: investigate if these are at the right indent level
info = None
val = None
validate_function_inputs = []
if hasattr(obj_class, "VALIDATE_INPUTS"):
validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
for x in required_inputs:
if x not in inputs:
error = {
@ -591,29 +603,7 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
errors.append(error)
continue
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
# ret = obj_class.VALIDATE_INPUTS(**input_data_all)
ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS")
for i, r3 in enumerate(ret):
if r3 is not True:
details = f"{x}"
if r3 is not False:
details += f" - {str(r3)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
else:
if x not in validate_function_inputs:
if isinstance(type_input, list):
if val not in type_input:
input_config = info
@ -640,6 +630,35 @@ def validate_inputs(prompt, item, validated) -> Tuple[bool, typing.List[dict], t
errors.append(error)
continue
if len(validate_function_inputs) > 0:
input_data_all = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs:
input_filtered[x] = input_data_all[x]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered:
for i, r in enumerate(ret):
if r is not True:
details = f"{x}"
if r is not False:
details += f" - {str(r)}"
error = {
"type": "custom_validation_failed",
"message": "Custom validation failed for node",
"details": details,
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id)
else:
@ -771,7 +790,7 @@ class PromptQueue:
self.server.queue_updated()
self.not_empty.notify()
def get(self, timeout=None) -> typing.Tuple[QueueTuple, int]:
def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, int]]:
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait(timeout=timeout)

View File

@ -188,8 +188,7 @@ def cached_filename_list_(folder_name):
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
if time.perf_counter() < (out[2] + 0.5):
return out
for x in out[1]:
time_modified = out[1][x]
folder = x

View File

@ -23,9 +23,9 @@ def execute_prestartup_script():
return False
node_paths = folder_paths.get_folder_paths("custom_nodes")
node_prestartup_times = []
for custom_node_path in node_paths:
possible_modules = os.listdir(custom_node_path) if os.path.exists(custom_node_path) else []
node_prestartup_times = []
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
@ -69,6 +69,10 @@ if args.cuda_device is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
print("Set cuda device to:", args.cuda_device)
if args.deterministic:
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
from .. import utils
import yaml
@ -78,12 +82,12 @@ from .server import BinaryEventTypes
from .. import model_management
def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer):
def prompt_worker(q, _server):
e = execution.PromptExecutor(_server)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
current_time = 0.0
while True:
timeout = None
if need_gc:
@ -94,11 +98,13 @@ def prompt_worker(q: execution.PromptQueue, _server: server_module.PromptServer)
item, item_id = queue_item
execution_start_time = time.perf_counter()
prompt_id = item[1]
_server.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
q.task_done(item_id, e.outputs_ui)
if _server.client_id is not None:
_server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, _server.client_id)
_server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, _server.client_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
@ -119,7 +125,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
def hijack_progress(server):
def hook(value, total, preview_image):
server.send_sync("progress", {"value": value, "max": total}, server.client_id)
model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
server.send_sync("progress", progress, server.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
@ -204,7 +213,7 @@ def main():
print(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
# These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))

View File

@ -734,7 +734,8 @@ class PromptServer():
message = self.encode_bytes(event, data)
if sid is None:
for ws in self.sockets.values():
sockets = list(self.sockets.values())
for ws in sockets:
await send_socket_catch_exception(ws.send_bytes, message)
elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_bytes, message)
@ -743,7 +744,8 @@ class PromptServer():
message = {"type": event, "data": data}
if sid is None:
for ws in self.sockets.values():
sockets = list(self.sockets.values())
for ws in sockets:
await send_socket_catch_exception(ws.send_json, message)
elif sid in self.sockets:
await send_socket_catch_exception(self.sockets[sid].send_json, message)

View File

@ -1,10 +1,13 @@
import torch
import math
import os
import contextlib
from . import utils
from . import model_management
from . import model_detection
from . import model_patcher
from . import ops
from .cldm import cldm
from .t2i_adapter import adapter
@ -34,13 +37,13 @@ class ControlBase:
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (0.0, 1.0)
self.global_average_pooling = False
self.timestep_range = None
if device is None:
device = model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint
@ -75,6 +78,7 @@ class ControlBase:
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
c.global_average_pooling = self.global_average_pooling
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@ -127,12 +131,14 @@ class ControlBase:
return out
class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device)
self.control_model = control_model
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
self.load_device = load_device
self.control_model_wrapped = model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@ -146,28 +152,31 @@ class ControlNet(ControlBase):
else:
return None
dtype = self.control_model.dtype
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
context = cond['c_crossattn']
y = cond.get('y', None)
if y is not None:
y = y.to(self.control_model.dtype)
y = y.to(dtype)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
self.copy_to(c)
return c
@ -198,10 +207,11 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias = ops.cast_bias_weight(self, input)
if self.up is not None:
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
return torch.nn.functional.linear(input, weight, bias)
class Conv2d(torch.nn.Module):
def __init__(
@ -237,16 +247,11 @@ class ControlLoraOps:
def forward(self, input):
weight, bias = ops.cast_bias_weight(self, input)
if self.up is not None:
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
def conv_nd(self, dims, *args, **kwargs):
if dims == 2:
return self.Conv2d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
class ControlLora(ControlNet):
@ -260,17 +265,26 @@ class ControlLora(ControlNet):
controlnet_config = model.model_config.unet_config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps()
self.control_model = cldm.ControlNet(**controlnet_config)
self.manual_cast_dtype = model.manual_cast_dtype
dtype = model.get_dtype()
self.control_model.to(dtype)
if self.manual_cast_dtype is None:
class control_lora_ops(ControlLoraOps, ops.disable_weight_init):
pass
else:
class control_lora_ops(ControlLoraOps, ops.manual_cast):
pass
dtype = self.manual_cast_dtype
controlnet_config["operations"] = control_lora_ops
controlnet_config["dtype"] = dtype
self.control_model = cldm.ControlNet(**controlnet_config)
self.control_model.to(model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
cm = self.control_model.state_dict()
for k in sd:
weight = model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
weight = sd[k]
try:
utils.set_attr(self.control_model, k, weight)
except:
@ -367,6 +381,10 @@ def load_controlnet(ckpt_path, model=None):
if controlnet_config is None:
unet_dtype = model_management.unet_dtype()
controlnet_config = model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = ops.manual_cast
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = cldm.ControlNet(**controlnet_config)
@ -395,14 +413,12 @@ def load_controlnet(ckpt_path, model=None):
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)
control_model = control_model.to(unet_dtype)
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
class T2IAdapter(ControlBase):

View File

@ -33,3 +33,7 @@ class SDXL(LatentFormat):
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder"
class SD_X4(LatentFormat):
def __init__(self):
self.scale_factor = 0.08333

View File

@ -9,6 +9,7 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..util import instantiate_from_config, get_obj_from_str
from ..modules.ema import LitEma
from ... import ops
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True):
@ -162,12 +163,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
self.quant_conv = ops.disable_weight_init.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:

View File

@ -18,6 +18,7 @@ if model_management.xformers_enabled():
from ...cli_args import args
from ... import ops
ops = ops.disable_weight_init
# CrossAttn precision handling
if args.dont_upcast_attention:
@ -82,16 +83,6 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
@ -112,19 +103,20 @@ def attention_basic(q, k, v, heads, mask=None):
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * scale
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
if mask.dtype == torch.bool:
mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
else:
sim += mask
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
@ -349,6 +341,18 @@ else:
if model_management.pytorch_attention_enabled():
optimized_attention_masked = attention_pytorch
def optimized_attention_for_device(device, mask=False):
if device == torch.device("cpu"): #TODO
if model_management.pytorch_attention_enabled():
return attention_pytorch
else:
return attention_basic
if mask:
return optimized_attention_masked
return optimized_attention
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
@ -393,7 +397,7 @@ class BasicTransformerBlock(nn.Module):
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
self.disable_self_attn = disable_self_attn
@ -413,10 +417,10 @@ class BasicTransformerBlock(nn.Module):
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.checkpoint = checkpoint
self.n_heads = n_heads
self.d_head = d_head
@ -558,7 +562,7 @@ class SpatialTransformer(nn.Module):
context_dim = [context_dim] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels, dtype=dtype, device=device)
self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
if not use_linear:
self.proj_in = operations.Conv2d(in_channels,
inner_dim,

View File

@ -8,6 +8,7 @@ from typing import Optional, Any
from .... import model_management
from .... import ops
ops = ops.disable_weight_init
if model_management.xformers_enabled_vae():
import xformers
@ -40,7 +41,7 @@ def nonlinearity(x):
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):

View File

@ -12,13 +12,13 @@ from .util import (
checkpoint,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from ...util import exists
from .... import ops
ops = ops.disable_weight_init
class TimestepBlock(nn.Module):
"""
@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
padding = kernel_size // 2
self.in_layers = nn.Sequential(
nn.GroupNorm(32, channels, dtype=dtype, device=device),
operations.GroupNorm(32, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
)
@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
),
)
self.out_layers = nn.Sequential(
nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
),
operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
,
)
if self.out_channels == channels:
@ -438,9 +437,6 @@ class UNetModel(nn.Module):
operations=ops,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
@ -457,7 +453,6 @@ class UNetModel(nn.Module):
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
@ -503,7 +498,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
@ -810,13 +805,13 @@ class UNetModel(nn.Module):
self._feature_size += ch
self.out = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(),
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
@ -842,14 +837,14 @@ class UNetModel(nn.Module):
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
h = x
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)

View File

@ -41,10 +41,14 @@ class AbstractLowScaleModel(nn.Module):
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def q_sample(self, x_start, t, noise=None, seed=None):
if noise is None:
if seed is None:
noise = torch.randn_like(x_start)
else:
noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
def forward(self, x):
return x, None
@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level
def forward(self, x, noise_level=None):
def forward(self, x, noise_level=None, seed=None):
if noise_level is None:
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level)
z = self.q_sample(x, noise_level, seed=seed)
return z, noise_level

View File

@ -16,7 +16,6 @@ import numpy as np
from einops import repeat, rearrange
from ...util import instantiate_from_config
from .... import ops
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
@ -52,9 +51,9 @@ class AlphaBlender(nn.Module):
if self.merge_strategy == "fixed":
# make shape compatible
# alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
alpha = self.mix_factor
alpha = self.mix_factor.to(image_only_indicator.device)
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
# make shape compatible
# alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
elif self.merge_strategy == "learned_with_images":
@ -62,7 +61,7 @@ class AlphaBlender(nn.Module):
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
# make shape compatible
@ -273,46 +272,6 @@ def mean_flat(tensor):
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels, dtype=None):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels, dtype=dtype)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return ops.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return ops.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.

View File

@ -15,12 +15,12 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def scale(self, x):
# re-normalize to centered mean and unit variance
x = (x - self.data_mean) * 1. / self.data_std
x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device)
return x
def unscale(self, x):
# back to original data stats
x = (x * self.data_std) + self.data_mean
x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device)
return x
def forward(self, x, noise_level=None):

View File

@ -5,6 +5,7 @@ import torch
from einops import rearrange, repeat
from ... import ops
ops = ops.disable_weight_init
from .diffusionmodules.model import (
AttnBlock,
@ -81,14 +82,14 @@ class VideoResBlock(ResnetBlock):
x = self.time_stack(x, temb)
alpha = self.get_alpha(bs=b // timesteps)
alpha = self.get_alpha(bs=b // timesteps).to(x.device)
x = alpha * x + (1.0 - alpha) * x_mix
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
class AE3DConv(torch.nn.Conv2d):
class AE3DConv(ops.Conv2d):
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
super().__init__(in_channels, out_channels, *args, **kwargs)
if isinstance(video_kernel_size, Iterable):
@ -96,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d):
else:
padding = int(video_kernel_size // 2)
self.time_mix_conv = torch.nn.Conv3d(
self.time_mix_conv = ops.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=video_kernel_size,
@ -166,7 +167,7 @@ class AttnVideoBlock(AttnBlock):
emb = emb[:, None, :]
x_mix = x_mix + emb
alpha = self.get_alpha()
alpha = self.get_alpha().to(x.device)
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge

View File

@ -43,7 +43,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
@ -64,7 +64,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
@ -116,8 +116,19 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
#glora
a1_name = "{}.a1.weight".format(x)
a2_name = "{}.a2.weight".format(x)
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
loaded_keys.add(b2_name)
w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
@ -126,21 +137,21 @@ def load_lora(lora, to_load):
if w_norm is not None:
loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = (w_norm,)
patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None:
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,)
patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
for x in lora.keys():

View File

@ -1,10 +1,12 @@
import torch
from .ldm.modules.diffusionmodules.openaimodel import UNetModel
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.modules.diffusionmodules.openaimodel import Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from . import model_management
from . import conds
from . import ops
from enum import Enum
import contextlib
from . import utils
class ModelType(Enum):
@ -13,7 +15,7 @@ class ModelType(Enum):
V_PREDICTION_EDM = 3
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
from .model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
def model_sampling(model_config, model_type):
@ -40,9 +42,14 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device)
if self.manual_cast_dtype is not None:
operations = ops.manual_cast
else:
operations = ops.disable_weight_init
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@ -61,15 +68,21 @@ class BaseModel(torch.nn.Module):
context = c_crossattn
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float()
context = context.to(dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "to"):
extra = extra.to(dtype)
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra_conds[o] = extra
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
@ -117,6 +130,10 @@ class BaseModel(torch.nn.Module):
adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = conds.CONDRegular(adm)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
return out
def load_model_weights(self, sd, unet_prefix=""):
@ -144,11 +161,7 @@ class BaseModel(torch.nn.Module):
def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_sd = self.diffusion_model.state_dict()
unet_state_dict = {}
for k in unet_sd:
unet_state_dict[k] = model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16:
@ -165,9 +178,12 @@ class BaseModel(torch.nn.Module):
def memory_required(self, input_shape):
if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024)
return (area * model_management.dtype_size(dtype) / 50) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * input_shape[2] * input_shape[3]
@ -307,9 +323,75 @@ class SVD_img2vid(BaseModel):
out['c_concat'] = conds.CONDNoiseShape(latent_image)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
if "time_conditioning" in kwargs:
out["time_context"] = conds.CONDCrossAttn(kwargs["time_conditioning"])
out['image_only_indicator'] = conds.CONDConstant(torch.zeros((1,), device=device))
out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
return out
class Stable_Zero123(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
super().__init__(model_config, model_type, device=device)
self.cc_projection = ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
self.cc_projection.weight.copy_(cc_projection_weight)
self.cc_projection.bias.copy_(cc_projection_bias)
def extra_conds(self, **kwargs):
out = {}
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
if latent_image is None:
latent_image = torch.zeros_like(noise)
if latent_image.shape[1:] != noise.shape[1:]:
latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])
out['c_concat'] = conds.CONDNoiseShape(latent_image)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if cross_attn.shape[-1] != 768:
cross_attn = self.cc_projection(cross_attn)
out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
return out
class SD_X4Upscaler(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device)
self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)
def extra_conds(self, **kwargs):
out = {}
image = kwargs.get("concat_image", None)
noise = kwargs.get("noise", None)
noise_augment = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
seed = kwargs["seed"] - 10
noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)
if image is None:
image = torch.zeros_like(noise)[:,:3]
if image.shape[1:] != noise.shape[1:]:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
noise_level = torch.tensor([noise_level], device=device)
if noise_augment > 0:
image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)
image = utils.resize_to_batch_size(image, noise.shape[0])
out['c_concat'] = conds.CONDNoiseShape(image)
out['y'] = conds.CONDRegular(noise_level)
return out

View File

@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype):
unet_config = {
"use_checkpoint": False,
"image_size": 32,
"out_channels": 4,
"use_spatial_transformer": True,
"legacy": False
}
@ -50,6 +49,12 @@ def detect_unet_config(state_dict, key_prefix, dtype):
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
out_key = '{}out.2.weight'.format(key_prefix)
if out_key in state_dict:
out_channels = state_dict[out_key].shape[0]
else:
out_channels = 4
num_res_blocks = []
channel_mult = []
attention_resolutions = []
@ -122,6 +127,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
transformer_depth_middle = -1
unet_config["in_channels"] = in_channels
unet_config["out_channels"] = out_channels
unet_config["model_channels"] = model_channels
unet_config["num_res_blocks"] = num_res_blocks
unet_config["transformer_depth"] = transformer_depth
@ -289,7 +295,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega]
for unet_config in supported_models:
matches = True

View File

@ -2,6 +2,7 @@ import psutil
from enum import Enum
from .cli_args import args
from . import utils
import torch
import sys
@ -28,6 +29,10 @@ total_vram = 0
lowvram_available = True
xpu_available = False
if args.deterministic:
print("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True)
directml_enabled = False
if args.directml is not None:
import torch_directml
@ -182,6 +187,9 @@ except:
if is_intel_xpu():
VAE_DTYPE = torch.bfloat16
if args.cpu_vae:
VAE_DTYPE = torch.float32
if args.fp16_vae:
VAE_DTYPE = torch.float16
elif args.bf16_vae:
@ -214,15 +222,8 @@ if args.force_fp16 or cpu_state == CPUState.MPS:
FORCE_FP16 = True
if lowvram_available:
try:
import accelerate
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
if cpu_state != CPUState.GPU:
@ -262,6 +263,14 @@ print("VAE dtype:", VAE_DTYPE)
current_loaded_models = []
def module_size(module):
module_mem = 0
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
return module_mem
class LoadedModel:
def __init__(self, model):
self.model = model
@ -294,8 +303,20 @@ class LoadedModel:
if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
mem_counter = 0
for m in self.real_model.modules():
if hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
module_mem = module_size(m)
if mem_counter + module_mem < lowvram_model_memory:
m.to(self.device)
mem_counter += module_mem
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
m.to(self.device)
mem_counter += module_size(m)
print("lowvram: loaded module regularly", m)
self.model_accelerated = True
if is_intel_xpu() and not args.disable_ipex_optimize:
@ -305,7 +326,11 @@ class LoadedModel:
def model_unload(self):
if self.model_accelerated:
accelerate.hooks.remove_hook_from_submodules(self.real_model)
for m in self.real_model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device)
@ -398,14 +423,14 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM
else:
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 256 * 1024 * 1024
lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model)
@ -430,6 +455,13 @@ def dtype_size(dtype):
dtype_size = 4
if dtype == torch.float16 or dtype == torch.bfloat16:
dtype_size = 2
elif dtype == torch.float32:
dtype_size = 4
else:
try:
dtype_size = dtype.itemsize
except: #Old pytorch doesn't have .itemsize
pass
return dtype_size
def unet_offload_device():
@ -459,10 +491,30 @@ def unet_inital_load_device(parameters, dtype):
def unet_dtype(device=None, model_params=0):
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
return torch.float16
if args.fp8_e4m3fn_unet:
return torch.float8_e4m3fn
if args.fp8_e5m2_unet:
return torch.float8_e5m2
if should_use_fp16(device=device, model_params=model_params):
return torch.float16
return torch.float32
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device):
if weight_dtype == torch.float32:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None
if fp16_supported:
return torch.float16
else:
return torch.float32
def text_encoder_offload_device():
if args.gpu_only:
return get_torch_device()
@ -492,12 +544,23 @@ def text_encoder_dtype(device=None):
elif args.fp32_text_enc:
return torch.float32
if is_device_cpu(device):
return torch.float16
if should_use_fp16(device, prioritize_performance=False):
return torch.float16
else:
return torch.float32
def intermediate_device():
if args.gpu_only:
return get_torch_device()
else:
return torch.device("cpu")
def vae_device():
if args.cpu_vae:
return torch.device("cpu")
return get_torch_device()
def vae_offload_device():
@ -515,6 +578,22 @@ def get_autocast_device(dev):
return dev.type
return "cuda"
def supports_dtype(device, dtype): #TODO
if dtype == torch.float32:
return True
if is_device_cpu(device):
return False
if dtype == torch.float16:
return True
if dtype == torch.bfloat16:
return True
return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
return True
def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@ -525,15 +604,17 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu():
device_supports_cast = True
non_blocking = device_supports_non_blocking(device)
if device_supports_cast:
if copy:
if tensor.device == device:
return tensor.to(dtype, copy=copy)
return tensor.to(device, copy=copy).to(dtype)
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device).to(dtype)
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(dtype).to(device, copy=copy)
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def xformers_enabled():
global directml_enabled
@ -687,11 +768,11 @@ def soft_empty_cache(force=False):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def resolve_lowvram_weight(weight, model, key):
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
def unload_all_models():
free_memory(1e30, get_torch_device())
def resolve_lowvram_weight(weight, model, key): #TODO: remove
return weight
#TODO: might be cleaner to put this somewhere else

View File

@ -28,13 +28,9 @@ class ModelPatcher:
if self.size > 0:
return self.size
model_sd = self.model.state_dict()
size = 0
for k in model_sd:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.size = model_management.module_size(self.model)
self.model_keys = set(model_sd.keys())
return size
return self.size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
@ -55,11 +51,18 @@ class ModelPatcher:
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
def set_model_sampler_cfg_function(self, sampler_cfg_function):
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True
def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
@ -70,13 +73,17 @@ class ModelPatcher:
to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number):
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
to = self.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch
if transformer_index is not None:
block = (block_name, number, transformer_index)
else:
block = (block_name, number)
to["patches_replace"][name][block] = patch
def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch")
@ -84,11 +91,11 @@ class ModelPatcher:
def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn1", block_name, number)
def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
def set_model_attn2_replace(self, patch, block_name, number):
self.set_model_patch_replace(patch, "attn2", block_name, number)
def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch")
@ -167,40 +174,41 @@ class ModelPatcher:
sd.pop(k)
return sd
def patch_model(self, device_to=None):
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = getattr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", key)
continue
if patch_weights:
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", key)
continue
weight = model_sd[key]
weight = model_sd[key]
inplace_update = self.weight_inplace_update
inplace_update = self.weight_inplace_update
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if key not in self.backup:
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
if device_to is not None:
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
if inplace_update:
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr(self.model, key, out_weight)
del temp_weight
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
self.model.to(device_to)
self.current_device = device_to
return self.model
@ -217,13 +225,19 @@ class ModelPatcher:
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * model_management.cast_to_device(w1, weight.device, weight.dtype)
elif len(v) == 4: #lora/locon
elif patch_type == "lora": #lora/locon
mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
if v[2] is not None:
@ -237,7 +251,7 @@ class ModelPatcher:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif len(v) == 8: #lokr
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
@ -276,7 +290,7 @@ class ModelPatcher:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
else: #loha
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
@ -305,6 +319,18 @@ class ModelPatcher:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
else:
print("patch type not recognized", patch_type, key)
return weight

View File

@ -22,10 +22,17 @@ class V_PREDICTION(EPS):
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
beta_schedule = "linear"
if model_config is not None:
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
beta_schedule = sampling_settings.get("beta_schedule", "linear")
linear_start = sampling_settings.get("linear_start", 0.00085)
linear_end = sampling_settings.get("linear_end", 0.012)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,

View File

@ -6,7 +6,7 @@ import hashlib
import math
import random
from PIL import Image, ImageOps
from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
@ -930,8 +930,8 @@ class GLIGENTextBoxApply:
return (c, )
class EmptyLatentImage:
def __init__(self, device="cpu"):
self.device = device
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def INPUT_TYPES(s):
@ -944,7 +944,7 @@ class EmptyLatentImage:
CATEGORY = "latent"
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
@ -1395,17 +1395,30 @@ class LoadImage:
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
img = Image.open(image_path)
output_images = []
output_masks = []
for i in ImageSequence.Iterator(img):
i = ImageOps.exif_transpose(i)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (image, mask.unsqueeze(0))
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):
@ -1463,13 +1476,10 @@ class LoadImageMask:
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image, channel):
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
if channel not in s._color_channels:
return "Invalid color channel: {}".format(channel)
return True
class ImageScale:

View File

@ -1,40 +1,115 @@
import torch
from contextlib import contextmanager
import comfy.model_management
class Linear(torch.nn.Linear):
def reset_parameters(self):
return None
def cast_bias_weight(s, input):
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
return weight, bias
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
class Conv3d(torch.nn.Conv3d):
def reset_parameters(self):
return None
class disable_weight_init:
class Linear(torch.nn.Linear):
comfy_cast_weights = False
def reset_parameters(self):
return None
def conv_nd(dims, *args, **kwargs):
if dims == 2:
return Conv2d(*args, **kwargs)
elif dims == 3:
return Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
@contextmanager
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
old_torch_nn_linear = torch.nn.Linear
force_device = device
force_dtype = dtype
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
if force_device is not None:
device = force_device
if force_dtype is not None:
dtype = force_dtype
return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
torch.nn.Linear = linear_with_dtype
try:
yield
finally:
torch.nn.Linear = old_torch_nn_linear
class Conv2d(torch.nn.Conv2d):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv3d(torch.nn.Conv3d):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class GroupNorm(torch.nn.GroupNorm):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm):
comfy_cast_weights = False
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
return s.Conv2d(*args, **kwargs)
elif dims == 3:
return s.Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
class manual_cast(disable_weight_init):
class Linear(disable_weight_init.Linear):
comfy_cast_weights = True
class Conv2d(disable_weight_init.Conv2d):
comfy_cast_weights = True
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True
class LayerNorm(disable_weight_init.LayerNorm):
comfy_cast_weights = True

View File

@ -0,0 +1,9 @@
from importlib.resources import path
import os
def get_editable_resource_path(caller_file, *package_path):
filename = os.path.join(os.path.dirname(os.path.realpath(caller_file)), package_path[-1])
if not os.path.exists(filename):
filename = path(*package_path)
return filename

View File

@ -47,7 +47,8 @@ def convert_cond(cond):
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0])
model_conds["c_crossattn"] = conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
out.append(temp)
return out
@ -98,10 +99,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()
samples = samples.to(model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
@ -111,8 +112,8 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
sigmas = sigmas.to(model.load_device)
samples = samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()
samples = samples.to(model_management.intermediate_device())
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
return samples

View File

@ -1,256 +1,264 @@
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
import torch
import collections
from . import model_management
import math
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in conds:
timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in conds:
timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in conds:
area = conds['area']
if 'strength' in conds:
strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditioning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = conds.get('control', None)
patches = None
if 'gligen' in conds:
gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
return cond_obj(input_x, mult, conditioning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys():
return False
for k in c1:
if not c1[k].can_concat(c2[k]):
return False
return True
def can_concat_cond(c1, c2):
if c1.input_x.shape != c2.input_x.shape:
return False
def objects_concatable(obj1, obj2):
if (obj1 is None) != (obj2 is None):
return False
if obj1 is not None:
if obj1 is not obj2:
return False
return True
if not objects_concatable(c1.control, c2.control):
return False
if not objects_concatable(c1.patches, c2.patches):
return False
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
temp = {}
for x in c_list:
for k in x:
cur = temp.get(k, [])
cur.append(x[k])
temp[k] = cur
out = {}
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
return out
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0
UNCOND = 1
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
if cond_or_uncond[o] == COND:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
del mult
out_cond /= out_count
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in conds:
timestep_start = conds['timestep_start']
if timestep_in[0] > timestep_start:
return None
if 'timestep_end' in conds:
timestep_end = conds['timestep_end']
if timestep_in[0] < timestep_end:
return None
if 'area' in conds:
area = conds['area']
if 'strength' in conds:
strength = conds['strength']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
if 'mask' in conds:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in conds:
rr = 8
if area[2] != 0:
for t in range(rr):
mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1))
if (area[0] + area[2]) < x_in.shape[2]:
for t in range(rr):
mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1))
if area[3] != 0:
for t in range(rr):
mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1))
if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {}
model_conds = conds["model_conds"]
for c in model_conds:
conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
control = None
if 'control' in conds:
control = conds['control']
patches = None
if 'gligen' in conds:
gligen = conds['gligen']
patches = {}
gligen_type = gligen[0]
gligen_model = gligen[1]
if gligen_type == "position":
gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
else:
gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)
patches['middle_patch'] = [gligen_patch]
return (input_x, mult, conditionning, area, control, patches)
def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys():
return False
for k in c1:
if not c1[k].can_concat(c2[k]):
return False
return True
def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape:
return False
#control
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
#patches
if (c1[5] is None) != (c2[5] is None):
return False
if (c1[5] is not None):
if c1[5] is not c2[5]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list):
c_crossattn = []
c_concat = []
c_adm = []
crossattn_max_len = 0
temp = {}
for x in c_list:
for k in x:
cur = temp.get(k, [])
cur.append(x[k])
temp[k] = cur
out = {}
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
return out
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0
UNCOND = 1
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x += [p[0]]
mult += [p[1]]
c += [p[2]]
area += [p[3]]
cond_or_uncond += [o[1]]
control = p[4]
patches = p[5]
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
if cond_or_uncond[o] == COND:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
del mult
out_cond /= out_count
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
if math.isclose(cond_scale, 1.0):
uncond = None
cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return x - model_options["sampler_cfg_function"](args)
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
return uncond + (cond - uncond) * cond_scale
uncond_ = uncond
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)
return cfg_result
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
@ -272,10 +280,7 @@ class KSamplerX0Inpaint(torch.nn.Module):
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
if denoise_mask is not None:
out *= denoise_mask
if denoise_mask is not None:
out += self.latent_image * latent_mask
out = out * denoise_mask + self.latent_image * latent_mask
return out
def simple_scheduler(model, steps):
@ -590,6 +595,13 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
@ -601,13 +613,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'extra_conds'):
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
@ -630,7 +635,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
elif scheduler_name == "sgm_uniform":
sigmas = normal_scheduler(model, steps, sgm=True)
else:
print("error invalid scheduler", self.scheduler)
print("error invalid scheduler", scheduler_name)
return sigmas
def sampler_object(name):

View File

@ -148,12 +148,14 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
def __init__(self, sd=None, device=None, config=None):
def __init__(self, sd=None, device=None, config=None, dtype=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
self.downscale_ratio = 8
self.latent_channels = 4
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@ -169,6 +171,11 @@ class VAE:
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
ddconfig['ch_mult'] = [1, 2, 4]
self.downscale_ratio = 4
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
@ -185,8 +192,11 @@ class VAE:
device = model_management.vae_device()
self.device = device
offload_device = model_management.vae_offload_device()
self.vae_dtype = model_management.vae_dtype()
if dtype is None:
dtype = model_management.vae_dtype()
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@ -198,9 +208,9 @@ class VAE:
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
output = torch.clamp((
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
/ 3.0) / 2.0, min=0.0, max=1.0)
return output
@ -211,9 +221,9 @@ class VAE:
pbar = utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0
return samples
@ -225,15 +235,15 @@ class VAE:
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@ -249,10 +259,10 @@ class VAE:
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@ -429,11 +439,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
class WeightsLoader(torch.nn.Module):
pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
model_config.set_manual_cast(manual_cast_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@ -466,7 +480,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
print("left over keys:", left_over)
if output_model:
_model_patcher = model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
_model_patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
@ -477,6 +491,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
@ -497,13 +514,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model_config.set_manual_cast(manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys in unet:", left_over)
return model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path):
sd = utils.load_torch_file(unet_path)

View File

@ -1,6 +1,6 @@
import os
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
from transformers import CLIPTokenizer
from . import ops
import torch
import traceback
@ -8,6 +8,8 @@ import zipfile
from . import model_management
from pkg_resources import resource_filename
import contextlib
from . import clip_model
import json
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
@ -38,7 +40,7 @@ class ClipTokenWeightEncoder:
out, pooled = self.encode(to_encode)
if pooled is not None:
first_pooled = pooled[0:1].cpu()
first_pooled = pooled[0:1].to(model_management.intermediate_device())
else:
first_pooled = pooled
@ -55,8 +57,8 @@ class ClipTokenWeightEncoder:
output.append(z)
if (len(output) == 0):
return out[-1:].cpu(), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled
return out[-1:].to(model_management.intermediate_device()), first_pooled
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
@ -66,33 +68,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=clip_model.CLIPTextModel,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.num_layers = 12
if textmodel_path is not None:
self.transformer = model_class.from_pretrained(textmodel_path)
else:
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
config = config_class.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights():
self.transformer = model_class(config)
self.inner_name = inner_name
if dtype is not None:
self.transformer.to(dtype)
inner_model = getattr(self.transformer, self.inner_name)
if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd1_clip_config.json')
with open(textmodel_json_config) as f:
config = json.load(f)
self.transformer = model_class(config, dtype, device, ops.manual_cast)
self.num_layers = self.transformer.num_layers
self.max_length = max_length
if freeze:
@ -107,7 +97,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) <= self.num_layers
assert abs(layer_idx) < self.num_layers
self.clip_layer(layer_idx)
self.layer_default = (self.layer, self.layer_idx)
@ -118,7 +108,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
param.requires_grad = False
def clip_layer(self, layer_idx):
if abs(layer_idx) >= self.num_layers:
if abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
@ -173,41 +163,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
else:
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
z = outputs[1]
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
if outputs[2] is not None:
pooled_output = outputs[2].float()
else:
pooled_output = None
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output
def encode(self, tokens):

View File

@ -4,15 +4,15 @@ from . import sd1_clip
import os
class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=23
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
if not os.path.exists(textmodel_json_config):
textmodel_json_config = resource_filename('comfy', 'sd2_clip_config.json')
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):

View File

@ -3,13 +3,13 @@ import torch
import os
class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
if layer == "penultimate":
layer="hidden"
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
def load_sd(self, sd):
@ -37,7 +37,7 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx):

View File

@ -217,6 +217,16 @@ class SSD1B(SDXL):
"use_temporal_attention": False,
}
class Segmind_Vega(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 0, 1, 1, 2, 2],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class SVD_img2vid(supported_models_base.BASE):
unet_config = {
"model_channels": 320,
@ -242,5 +252,59 @@ class SVD_img2vid(supported_models_base.BASE):
def clip_target(self):
return None
models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
class Stable_Zero123(supported_models_base.BASE):
unet_config = {
"context_dim": 768,
"model_channels": 320,
"use_linear_in_transformer": False,
"adm_in_channels": None,
"use_temporal_attention": False,
"in_channels": 8,
}
unet_extra_config = {
"num_heads": 8,
"num_head_channels": -1,
}
clip_vision_prefix = "cond_stage_model.model.visual."
latent_format = latent_formats.SD15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
return out
def clip_target(self):
return None
class SD_X4Upscaler(SD20):
unet_config = {
"context_dim": 1024,
"model_channels": 256,
'in_channels': 7,
"use_linear_in_transformer": True,
"adm_in_channels": None,
"use_temporal_attention": False,
}
unet_extra_config = {
"disable_self_attentions": [True, True, True, False],
"num_classes": 1000,
"num_heads": 8,
"num_head_channels": -1,
}
latent_format = latent_formats.SD_X4
sampling_settings = {
"linear_start": 0.0001,
"linear_end": 0.02,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SD_X4Upscaler(self, device=device)
return out
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
models += [SVD_img2vid]

View File

@ -22,6 +22,8 @@ class BASE:
sampling_settings = {}
latent_format = latent_formats.LatentFormat
manual_cast_dtype = None
@classmethod
def matches(s, unet_config):
for k in s.unet_config:
@ -71,3 +73,5 @@ class BASE:
replace_prefix = {"": "first_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_manual_cast(self, manual_cast_dtype):
self.manual_cast_dtype = manual_cast_dtype

View File

@ -7,9 +7,10 @@ import torch
import torch.nn as nn
from .. import utils
from .. import ops
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
return ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
@ -19,7 +20,7 @@ class Block(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.skip = ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))

View File

@ -378,7 +378,7 @@ def lanczos(samples, width, height):
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
result = torch.stack(images)
return result
return result.to(samples.device, samples.dtype)
def common_upscale(samples, width, height, upscale_method, crop):
if crop == "center":
@ -407,17 +407,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
@torch.inference_mode()
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
for b in range(samples.shape[0]):
s = samples[b:b+1]
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
for y in range(0, s.shape[2], tile_y - overlap):
for x in range(0, s.shape[3], tile_x - overlap):
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
ps = function(s_in).cpu()
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):

View File

@ -291,7 +291,7 @@ class Canny:
def detect_edge(self, image, low_threshold, high_threshold):
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
return (img_out,)
NODE_CLASS_MAPPINGS = {

View File

@ -1,9 +1,9 @@
import comfy.samplers
import comfy.sample
from comfy import samplers
from comfy import sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
from comfy.cmd import latent_preview
import torch
import comfy.utils
from comfy import utils
class BasicScheduler:
@ -11,8 +11,9 @@ class BasicScheduler:
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"scheduler": (comfy.samplers.SCHEDULER_NAMES, ),
"scheduler": (samplers.SCHEDULER_NAMES, ),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
@ -20,8 +21,15 @@ class BasicScheduler:
FUNCTION = "get_sigmas"
def get_sigmas(self, model, scheduler, steps):
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu()
def get_sigmas(self, model, scheduler, steps, denoise):
total_steps = steps
if denoise < 1.0:
total_steps = int(steps/denoise)
inner_model = model.patch_model(patch_weights=False)
sigmas = samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
model.unpatch_model()
sigmas = sigmas[-(steps + 1):]
return (sigmas, )
@ -87,6 +95,7 @@ class SDTurboScheduler:
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
@ -94,9 +103,12 @@ class SDTurboScheduler:
FUNCTION = "get_sigmas"
def get_sigmas(self, model, steps):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
sigmas = model.model.model_sampling.sigma(timesteps)
def get_sigmas(self, model, steps, denoise):
start_step = 10 - int(10 * denoise)
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
inner_model = model.patch_model(patch_weights=False)
sigmas = inner_model.model_sampling.sigma(timesteps)
model.unpatch_model()
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
@ -159,7 +171,7 @@ class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"sampler_name": (comfy.samplers.SAMPLER_NAMES, ),
{"sampler_name": (samplers.SAMPLER_NAMES, ),
}
}
RETURN_TYPES = ("SAMPLER",)
@ -168,7 +180,7 @@ class KSamplerSelect:
FUNCTION = "get_sampler"
def get_sampler(self, sampler_name):
sampler = comfy.samplers.sampler_object(sampler_name)
sampler = samplers.sampler_object(sampler_name)
return (sampler, )
class SamplerDPMPP_2M_SDE:
@ -191,7 +203,7 @@ class SamplerDPMPP_2M_SDE:
sampler_name = "dpmpp_2m_sde"
else:
sampler_name = "dpmpp_2m_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
return (sampler, )
@ -215,7 +227,7 @@ class SamplerDPMPP_SDE:
sampler_name = "dpmpp_sde"
else:
sampler_name = "dpmpp_sde_gpu"
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, )
class SamplerCustom:
@ -248,7 +260,7 @@ class SamplerCustom:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = comfy.sample.prepare_noise(latent_image, noise_seed, batch_inds)
noise = sample.prepare_noise(latent_image, noise_seed, batch_inds)
noise_mask = None
if "noise_mask" in latent:
@ -257,8 +269,8 @@ class SamplerCustom:
x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
disable_pbar = not utils.PROGRESS_BAR_ENABLED
samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy()
out["samples"] = samples

View File

@ -2,9 +2,10 @@
import math
from einops import rearrange
import random
# Use torch rng for consistency across generations
from torch import randint
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
min_value = min(min_value, value)
# All big divisors of value (inclusive)
@ -12,8 +13,10 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
random.seed(counter)
idx = random.randint(0, len(ns) - 1)
if len(ns) - 1 > 0:
idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
else:
idx = 0
return ns[idx]
@ -42,7 +45,6 @@ class HyperTile:
latent_tile_size = max(32, tile_size) // 8
self.temp = None
self.counter = 1
def hypertile_in(q, k, v, extra_options):
if q.shape[-1] in apply_to:
@ -53,10 +55,8 @@ class HyperTile:
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
self.counter += 1
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
self.counter += 1
nh = random_divisor(h, latent_tile_size * factor, swap_size)
nw = random_divisor(w, latent_tile_size * factor, swap_size)
if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

View File

@ -73,7 +73,7 @@ class SaveAnimatedWEBP:
OUTPUT_NODE = True
CATEGORY = "_for_testing"
CATEGORY = "image/animation"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method)
@ -135,7 +135,7 @@ class SaveAnimatedPNG:
OUTPUT_NODE = True
CATEGORY = "_for_testing"
CATEGORY = "image/animation"
def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append

View File

@ -3,9 +3,7 @@ import torch
def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]:
latent.movedim(1, -1)
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
latent.movedim(-1, 1)
return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
@ -102,9 +100,32 @@ class LatentInterpolate:
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
class LatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "batch"
CATEGORY = "latent/batch"
def batch(self, samples1, samples2):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
if s1.shape[1:] != s2.shape[1:]:
s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch,
}

View File

@ -7,6 +7,7 @@ from comfy.nodes.common import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
source = source.to(destination.device)
if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
@ -21,7 +22,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
if mask is None:
mask = torch.ones_like(source)
else:
mask = mask.clone()
mask = mask.to(destination.device, copy=True)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])

View File

@ -16,41 +16,19 @@ class LCM(comfy.model_sampling.EPS):
return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteDistilled(torch.nn.Module):
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
original_timesteps = 50
def __init__(self):
super().__init__()
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012
def __init__(self, model_config=None):
super().__init__(model_config)
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
self.skip_steps = self.num_timesteps // self.original_timesteps
self.skip_steps = timesteps // self.original_timesteps
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
for x in range(self.original_timesteps):
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
self.set_sigmas(sigmas_valid)
def timestep(self, sigma):
log_sigma = sigma.log()
@ -65,14 +43,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp().to(timestep.device)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 999999999.9
if percent >= 1.0:
return 0.0
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
@ -121,7 +91,7 @@ class ModelSamplingDiscrete:
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling = ModelSamplingAdvanced(model.model.model_config)
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
@ -153,7 +123,7 @@ class ModelSamplingContinuousEDM:
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_sigma_range(sigma_min, sigma_max)
m.add_object_patch("model_sampling", model_sampling)
return (m, )

View File

@ -0,0 +1,53 @@
import torch
from comfy import sample
from comfy import samplers
class PerpNeg:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"empty_conditioning": ("CONDITIONING", ),
"neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, empty_conditioning, neg_scale):
m = model.clone()
nocond = sample.convert_cond(empty_conditioning)
def cfg_function(args):
model = args["model"]
noise_pred_pos = args["cond_denoised"]
noise_pred_neg = args["uncond_denoised"]
cond_scale = args["cond_scale"]
x = args["input"]
sigma = args["sigma"]
model_options = args["model_options"]
nocond_processed = samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
(noise_pred_nocond, _) = samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
pos = noise_pred_pos - noise_pred_nocond
neg = noise_pred_neg - noise_pred_nocond
perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg
perp_neg = perp * neg_scale
cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
cfg_result = x - cfg_result
return cfg_result
m.set_model_sampler_cfg_function(cfg_function)
return (m, )
NODE_CLASS_MAPPINGS = {
"PerpNeg": PerpNeg,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"PerpNeg": "Perp-Neg",
}

View File

@ -226,7 +226,7 @@ class Sharpen:
batch_size, height, width, channels = image.shape
kernel_size = sharpen_radius * 2 + 1
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
center = kernel_size // 2
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

View File

@ -99,10 +99,40 @@ class LatentRebatch:
return (output_list,)
class ImageRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )
FUNCTION = "rebatch"
CATEGORY = "image/batch"
def rebatch(self, images, batch_size):
batch_size = batch_size[0]
output_list = []
all_images = []
for img in images:
for i in range(img.shape[0]):
all_images.append(img[i:i+1])
for i in range(0, len(all_images), batch_size):
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
return (output_list,)
NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
"RebatchImages": ImageRebatch,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
}
"RebatchImages": "Rebatch Images",
}

View File

@ -0,0 +1,168 @@
import torch
from torch import einsum
import torch.nn.functional as F
import math
from einops import rearrange, repeat
import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
from comfy import samplers
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
def attention_basic_with_sim(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
del q, k
if mask is not None:
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return (out, sim)
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = math.ceil(math.sqrt(lh * lw / hw1))
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape
mask = (
mask.reshape(b, *mid_shape)
.unsqueeze(1)
.type(attn.dtype)
)
# Upsample
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class SelfAttentionGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, scale, blur_sigma):
m = model.clone()
attn_scores = None
# TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call
def attn_and_record(q, k, v, extra_options):
nonlocal attn_scores
# if uncond, save the attention scores
heads = extra_options["n_heads"]
cond_or_uncond = extra_options["cond_or_uncond"]
b = q.shape[0] // len(cond_or_uncond)
if 1 in cond_or_uncond:
uncond_index = cond_or_uncond.index(1)
# do the entire attention operation, but save the attention scores to attn_scores
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
return out
else:
return optimized_attention(q, k, v, heads=heads)
def post_cfg_function(args):
nonlocal attn_scores
uncond_attn = attn_scores
sag_scale = scale
sag_sigma = blur_sigma
sag_threshold = 1.0
model = args["model"]
uncond_pred = args["uncond_denoised"]
uncond = args["uncond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
model_options = args["model_options"]
x = args["input"]
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag, _) = samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
# from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
return (m, )
NODE_CLASS_MAPPINGS = {
"SelfAttentionGuidance": SelfAttentionGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SelfAttentionGuidance": "Self-Attention Guidance",
}

View File

@ -0,0 +1,46 @@
import torch
from comfy import utils
class SD_4XUpscale_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/upscale_diffusion"
def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
width = max(1, round(images.shape[-2] * scale_ratio))
height = max(1, round(images.shape[-3] * scale_ratio))
pixels = utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
out_cp = []
out_cn = []
for t in positive:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cp.append(n)
for t in negative:
n = [t[0], t[1].copy()]
n[1]['concat_image'] = pixels
n[1]['noise_augmentation'] = noise_augmentation
out_cn.append(n)
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
return (out_cp, out_cn, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
}

View File

@ -0,0 +1,61 @@
import torch
import comfy.utils
from comfy.nodes.common import MAX_RESOLUTION
from comfy import utils
def camera_embeddings(elevation, azimuth):
elevation = torch.as_tensor([elevation])
azimuth = torch.as_tensor([azimuth])
embeddings = torch.stack(
[
torch.deg2rad(
(90 - elevation) - (90)
), # Zero123 polar is 90-elevation
torch.sin(torch.deg2rad(azimuth)),
torch.cos(torch.deg2rad(azimuth)),
torch.deg2rad(
90 - torch.full_like(elevation, 0)
),
], dim=-1).unsqueeze(1)
return embeddings
class StableZero123_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 256, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
"azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/3d_models"
def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
t = vae.encode(encode_pixels)
cam_embeds = camera_embeddings(elevation, azimuth)
cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1)
positive = [[cond, {"concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})
NODE_CLASS_MAPPINGS = {
"StableZero123_Conditioning": StableZero123_Conditioning,
}

View File

@ -3,6 +3,7 @@ torchaudio
torchvision
torchdiffeq>=0.2.3
torchsde>=0.2.6
torchvision
einops>=0.6.0
open-clip-torch>=2.16.0
transformers>=4.29.1

View File

@ -28,18 +28,18 @@ version = '0.0.1'
"""
The package index to the torch built with AMD ROCm.
"""
amd_torch_index = "https://download.pytorch.org/whl/rocm5.6"
amd_torch_index = ("https://download.pytorch.org/whl/rocm5.6", "https://download.pytorch.org/whl/nightly/rocm5.7")
"""
The package index to torch built with CUDA.
Observe the CUDA version is in this URL.
"""
nvidia_torch_index = "https://download.pytorch.org/whl/cu121"
nvidia_torch_index = ("https://download.pytorch.org/whl/cu121", "https://download.pytorch.org/whl/nightly/cu121")
"""
The package index to torch built against CPU features.
"""
cpu_torch_index = "https://download.pytorch.org/whl/cpu"
cpu_torch_index = ("https://download.pytorch.org/whl/cpu", "https://download.pytorch.org/whl/nightly/cpu")
# xformers not required for new torch
@ -102,11 +102,11 @@ def _is_linux_arm64():
def dependencies() -> List[str]:
_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines()
# todo: also add all plugin dependencies
_alternative_indices = [amd_torch_index, nvidia_torch_index]
session = PipSession()
index_urls = ['https://pypi.org/simple']
# (stable, nightly) tuple
index_urls = [('https://pypi.org/simple', 'https://pypi.org/simple')]
# prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device
if _is_nvidia():
index_urls += [nvidia_torch_index]
@ -118,6 +118,13 @@ def dependencies() -> List[str]:
if len(index_urls) == 1:
return _dependencies
if sys.version_info >= (3, 12):
# use the nightlies
index_urls = [nightly for (_, nightly) in index_urls]
_alternative_indices = [nightly for (_, nightly) in _alternative_indices]
else:
index_urls = [stable for (stable, _) in index_urls]
_alternative_indices = [stable for (stable, _) in _alternative_indices]
try:
# pip 23
finder = PackageFinder.create(LinkCollector(session, SearchScope([], index_urls, no_index=False)),
@ -149,7 +156,7 @@ setup(
description="",
author="",
version=version,
python_requires=">=3.9,<3.12",
python_requires=">=3.9,<3.13",
# todo: figure out how to include the web directory to eventually let main live inside the package
# todo: see https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ for more about adding plugins
packages=find_packages(exclude=[] if is_editable else ['custom_nodes']),

9
tests-ui/afterSetup.js Normal file
View File

@ -0,0 +1,9 @@
const { start } = require("./utils");
const lg = require("./utils/litegraph");
// Load things once per test file before to ensure its all warmed up for the tests
beforeAll(async () => {
lg.setup(global);
await start({ resetEnv: true });
lg.teardown(global);
});

View File

@ -2,8 +2,10 @@
const config = {
testEnvironment: "jsdom",
setupFiles: ["./globalSetup.js"],
setupFilesAfterEnv: ["./afterSetup.js"],
clearMocks: true,
resetModules: true,
testTimeout: 10000
};
module.exports = config;

View File

@ -52,7 +52,7 @@ describe("extensions", () => {
const nodeNames = Object.keys(defs);
const nodeCount = nodeNames.length;
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount);
for (let i = 0; i < nodeCount; i++) {
for (let i = 0; i < 10; i++) {
// It should be send the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0];
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1];
@ -133,7 +133,7 @@ describe("extensions", () => {
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2);
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1);
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2);
});
}, 15000);
it("allows custom nodeDefs and widgets to be registered", async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => {

View File

@ -1,7 +1,7 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, createDefaultWorkflow } = require("../utils");
const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils");
const lg = require("../utils/litegraph");
describe("group node", () => {
@ -273,7 +273,7 @@ describe("group node", () => {
let reroutes = [];
let prevNode = nodes.ckpt;
for(let i = 0; i < 5; i++) {
for (let i = 0; i < 5; i++) {
const reroute = ez.Reroute();
prevNode.outputs[0].connectTo(reroute.inputs[0]);
prevNode = reroute;
@ -283,7 +283,7 @@ describe("group node", () => {
const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]);
expect((await graph.toPrompt()).output).toEqual(getOutput());
group.menu["Convert to nodes"].call();
expect((await graph.toPrompt()).output).toEqual(getOutput());
});
@ -383,6 +383,43 @@ describe("group node", () => {
getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id])
);
});
test("groups can connect to each other via internal reroutes", async () => {
const { ez, graph, app } = await start();
const latent = ez.EmptyLatentImage();
const vae = ez.VAELoader();
const latentReroute = ez.Reroute();
const vaeReroute = ez.Reroute();
latent.outputs[0].connectTo(latentReroute.inputs[0]);
vae.outputs[0].connectTo(vaeReroute.inputs[0]);
const group1 = await convertToGroup(app, graph, "test", [latentReroute, vaeReroute]);
group1.menu.Clone.call();
expect(app.graph._nodes).toHaveLength(4);
const group2 = graph.find(app.graph._nodes[3]);
expect(group2.node.type).toEqual("workflow/test");
expect(group2.id).not.toEqual(group1.id);
group1.outputs.VAE.connectTo(group2.inputs.VAE);
group1.outputs.LATENT.connectTo(group2.inputs.LATENT);
const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE);
const preview = ez.PreviewImage(decode.outputs[0]);
const output = {
[latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" },
[vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
[decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" },
[preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" },
};
expect((await graph.toPrompt()).output).toEqual(output);
// Ensure missing connections dont cause errors
group2.inputs.VAE.disconnect();
delete output[decode.id].inputs.vae;
expect((await graph.toPrompt()).output).toEqual(output);
});
test("displays generated image on group node", async () => {
const { ez, graph, app } = await start();
const nodes = createDefaultWorkflow(ez, graph);
@ -642,6 +679,55 @@ describe("group node", () => {
2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" },
});
});
test("correctly handles widget inputs", async () => {
const { ez, graph, app } = await start();
const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0];
const image = ez.LoadImage();
const scale1 = ez.ImageScaleBy(image.outputs[0]);
const scale2 = ez.ImageScaleBy(image.outputs[0]);
const preview1 = ez.PreviewImage(scale1.outputs[0]);
const preview2 = ez.PreviewImage(scale2.outputs[0]);
scale1.widgets.upscale_method.value = upscaleMethods[1];
scale1.widgets.upscale_method.convertToInput();
const group = await convertToGroup(app, graph, "test", [scale1, scale2]);
expect(group.inputs.length).toBe(3);
expect(group.inputs[0].input.type).toBe("IMAGE");
expect(group.inputs[1].input.type).toBe("IMAGE");
expect(group.inputs[2].input.type).toBe("COMBO");
// Ensure links are maintained
expect(group.inputs[0].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[1].connection?.originNode?.id).toBe(image.id);
expect(group.inputs[2].connection).toBeFalsy();
// Ensure primitive gets correct type
const primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(group.inputs[2]);
expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods);
expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied
primitive.widgets.value.value = upscaleMethods[1];
await checkBeforeAndAfterReload(graph, async (r) => {
const scale1id = r ? `${group.id}:0` : scale1.id;
const scale2id = r ? `${group.id}:1` : scale2.id;
// Ensure widget value is applied to prompt
expect((await graph.toPrompt()).output).toStrictEqual({
[image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" },
[scale1id]: {
inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[scale2id]: {
inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] },
class_type: "ImageScaleBy",
},
[preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" },
[preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" },
});
});
});
test("adds widgets in node execution order", async () => {
const { ez, graph, app } = await start();
const scale = ez.LatentUpscale();
@ -815,4 +901,105 @@ describe("group node", () => {
expect(p2.widgets.control_after_generate.value).toBe("randomize");
expect(p2.widgets.control_filter_list.value).toBe("/.+/");
});
test("internal reroutes work with converted inputs and merge options", async () => {
const { ez, graph, app } = await start();
const vae = ez.VAELoader();
const latent = ez.EmptyLatentImage();
const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE);
const scale = ez.ImageScale(decode.outputs.IMAGE);
ez.PreviewImage(scale.outputs.IMAGE);
const r1 = ez.Reroute();
const r2 = ez.Reroute();
latent.widgets.width.value = 64;
latent.widgets.height.value = 128;
latent.widgets.width.convertToInput();
latent.widgets.height.convertToInput();
latent.widgets.batch_size.convertToInput();
scale.widgets.width.convertToInput();
scale.widgets.height.convertToInput();
r1.inputs[0].input.label = "hbw";
r1.outputs[0].connectTo(latent.inputs.height);
r1.outputs[0].connectTo(latent.inputs.batch_size);
r1.outputs[0].connectTo(scale.inputs.width);
r2.inputs[0].input.label = "wh";
r2.outputs[0].connectTo(latent.inputs.width);
r2.outputs[0].connectTo(scale.inputs.height);
const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]);
expect(group.inputs[0].input.type).toBe("VAE");
expect(group.inputs[1].input.type).toBe("INT");
expect(group.inputs[2].input.type).toBe("INT");
const p1 = ez.PrimitiveNode();
const p2 = ez.PrimitiveNode();
p1.outputs[0].connectTo(group.inputs[1]);
p2.outputs[0].connectTo(group.inputs[2]);
expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max
expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
expect(p1.widgets.value.value).toBe(128);
expect(p2.widgets.value.value).toBe(64);
p1.widgets.value.value = 16;
p2.widgets.value.value = 32;
await checkBeforeAndAfterReload(graph, async (r) => {
const id = (v) => (r ? `${group.id}:` : "") + v;
expect((await graph.toPrompt()).output).toStrictEqual({
1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" },
[id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" },
[id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" },
[id(4)]: {
inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] },
class_type: "ImageScale",
},
5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" },
});
});
});
test("converted inputs with linked widgets map values correctly on creation", async () => {
const { ez, graph, app } = await start();
const k1 = ez.KSampler();
const k2 = ez.KSampler();
k1.widgets.seed.convertToInput();
k2.widgets.seed.convertToInput();
const rr = ez.Reroute();
rr.outputs[0].connectTo(k1.inputs.seed);
rr.outputs[0].connectTo(k2.inputs.seed);
const group = await convertToGroup(app, graph, "test", [k1, k2, rr]);
expect(group.widgets.steps.value).toBe(20);
expect(group.widgets.cfg.value).toBe(8);
expect(group.widgets.scheduler.value).toBe("normal");
expect(group.widgets["KSampler steps"].value).toBe(20);
expect(group.widgets["KSampler cfg"].value).toBe(8);
expect(group.widgets["KSampler scheduler"].value).toBe("normal");
});
test("allow multiple of the same node type to be added", async () => {
const { ez, graph, app } = await start();
const nodes = [...Array(10)].map(() => ez.ImageScaleBy());
const group = await convertToGroup(app, graph, "test", nodes);
expect(group.inputs.length).toBe(10);
expect(group.outputs.length).toBe(10);
expect(group.widgets.length).toBe(20);
expect(group.widgets.map((w) => w.widget.name)).toStrictEqual(
[...Array(10)]
.map((_, i) => `${i > 0 ? "ImageScaleBy " : ""}${i > 1 ? i + " " : ""}`)
.flatMap((p) => [`${p}upscale_method`, `${p}scale_by`])
);
});
});

View File

@ -1,7 +1,13 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils");
const {
start,
makeNodeDef,
checkBeforeAndAfterReload,
assertNotNullOrUndefined,
createDefaultWorkflow,
} = require("../utils");
const lg = require("../utils/litegraph");
/**
@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWi
if (controlWidgetCount) {
const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo");
if(widgetType === "combo") {
if (widgetType === "combo") {
const filterWidget = primitive.widgets.control_filter_list;
expect(filterWidget.widget.type).toBe("string");
}
@ -308,8 +314,8 @@ describe("widget inputs", () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", {}, [["A", "B"]]),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }),
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }),
},
});
@ -330,7 +336,7 @@ describe("widget inputs", () => {
const n1 = ez.TestNode1();
n1.widgets.example.convertToInput();
const p = ez.PrimitiveNode()
const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]);
const value = p.widgets.value;
@ -380,7 +386,7 @@ describe("widget inputs", () => {
// Check random
control.value = "randomize";
filter.value = "/D/";
for(let i = 0; i < 100; i++) {
for (let i = 0; i < 100; i++) {
control["afterQueued"]();
expect(value.value === "D" || value.value === "DD").toBeTruthy();
}
@ -392,4 +398,160 @@ describe("widget inputs", () => {
control["afterQueued"]();
expect(value.value).toBe("B");
});
describe("reroutes", () => {
async function checkOutput(graph, values) {
expect((await graph.toPrompt()).output).toStrictEqual({
1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" },
2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" },
4: {
inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 },
class_type: "EmptyLatentImage",
},
5: {
inputs: {
seed: 0,
steps: 20,
cfg: 8,
sampler_name: "euler",
scheduler: values?.scheduler ?? "normal",
denoise: 1,
model: ["1", 0],
positive: ["2", 0],
negative: ["3", 0],
latent_image: ["4", 0],
},
class_type: "KSampler",
},
6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" },
7: {
inputs: { filename_prefix: values.filename_prefix ?? "ComfyUI", images: ["6", 0] },
class_type: "SaveImage",
},
});
}
async function waitForWidget(node) {
// widgets are created slightly after the graph is ready
// hard to find an exact hook to get these so just wait for them to be ready
for (let i = 0; i < 10; i++) {
await new Promise((r) => setTimeout(r, 10));
if (node.widgets?.value) {
return;
}
}
}
it("can connect primitive via a reroute path to a widget input", async () => {
const { ez, graph } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.empty.widgets.width.convertToInput();
nodes.sampler.widgets.scheduler.convertToInput();
nodes.save.widgets.filename_prefix.convertToInput();
let widthReroute = ez.Reroute();
let schedulerReroute = ez.Reroute();
let fileReroute = ez.Reroute();
let widthNext = widthReroute;
let schedulerNext = schedulerReroute;
let fileNext = fileReroute;
for (let i = 0; i < 5; i++) {
let next = ez.Reroute();
widthNext.outputs[0].connectTo(next.inputs[0]);
widthNext = next;
next = ez.Reroute();
schedulerNext.outputs[0].connectTo(next.inputs[0]);
schedulerNext = next;
next = ez.Reroute();
fileNext.outputs[0].connectTo(next.inputs[0]);
fileNext = next;
}
widthNext.outputs[0].connectTo(nodes.empty.inputs.width);
schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler);
fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix);
let widthPrimitive = ez.PrimitiveNode();
let schedulerPrimitive = ez.PrimitiveNode();
let filePrimitive = ez.PrimitiveNode();
widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]);
schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]);
filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]);
expect(widthPrimitive.widgets.value.value).toBe(512);
widthPrimitive.widgets.value.value = 1024;
expect(schedulerPrimitive.widgets.value.value).toBe("normal");
schedulerPrimitive.widgets.value.value = "simple";
expect(filePrimitive.widgets.value.value).toBe("ComfyUI");
filePrimitive.widgets.value.value = "ComfyTest";
await checkBeforeAndAfterReload(graph, async () => {
widthPrimitive = graph.find(widthPrimitive);
schedulerPrimitive = graph.find(schedulerPrimitive);
filePrimitive = graph.find(filePrimitive);
await waitForWidget(filePrimitive);
expect(widthPrimitive.widgets.length).toBe(2);
expect(schedulerPrimitive.widgets.length).toBe(3);
expect(filePrimitive.widgets.length).toBe(1);
await checkOutput(graph, {
width: 1024,
scheduler: "simple",
filename_prefix: "ComfyTest",
});
});
});
it("can connect primitive via a reroute path to multiple widget inputs", async () => {
const { ez, graph } = await start();
const nodes = createDefaultWorkflow(ez, graph);
nodes.empty.widgets.width.convertToInput();
nodes.empty.widgets.height.convertToInput();
nodes.empty.widgets.batch_size.convertToInput();
let reroute = ez.Reroute();
let prevReroute = reroute;
for (let i = 0; i < 5; i++) {
const next = ez.Reroute();
prevReroute.outputs[0].connectTo(next.inputs[0]);
prevReroute = next;
}
const r1 = ez.Reroute(prevReroute.outputs[0]);
const r2 = ez.Reroute(prevReroute.outputs[0]);
const r3 = ez.Reroute(r2.outputs[0]);
const r4 = ez.Reroute(r2.outputs[0]);
r1.outputs[0].connectTo(nodes.empty.inputs.width);
r3.outputs[0].connectTo(nodes.empty.inputs.height);
r4.outputs[0].connectTo(nodes.empty.inputs.batch_size);
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(reroute.inputs[0]);
expect(primitive.widgets.value.value).toBe(1);
primitive.widgets.value.value = 64;
await checkBeforeAndAfterReload(graph, async (r) => {
primitive = graph.find(primitive);
await waitForWidget(primitive);
// Ensure widget configs are merged
expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max
expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
await checkOutput(graph, {
width: 64,
height: 64,
batch_size: 64,
});
});
});
});
});

View File

@ -78,6 +78,14 @@ export class EzInput extends EzSlot {
this.input = input;
}
get connection() {
const link = this.node.node.inputs?.[this.index]?.link;
if (link == null) {
return null;
}
return new EzConnection(this.node.app, this.node.app.graph.links[link]);
}
disconnect() {
this.node.node.disconnectInput(this.index);
}
@ -117,7 +125,7 @@ export class EzOutput extends EzSlot {
const inp = input.input;
const inName = inp.name || inp.label || inp.type;
throw new Error(
`Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${
`Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${
this.output.name ?? this.output.type
}#${this.index}] failed.`
);
@ -179,6 +187,7 @@ export class EzWidget {
set value(v) {
this.widget.value = v;
this.widget.callback?.call?.(this.widget, v)
}
get isConvertedToInput() {
@ -319,7 +328,7 @@ export class EzGraph {
}
stringify() {
return JSON.stringify(this.app.graph.serialize(), undefined, "\t");
return JSON.stringify(this.app.graph.serialize(), undefined);
}
/**

View File

@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) {
return { ckpt, pos, neg, empty, sampler, decode, save };
}
export async function getNodeDefs() {
const { api } = require("../../web/scripts/api");
return api.getNodeDefs();
}
export async function getNodeDef(nodeId) {
return (await getNodeDefs())[nodeId];
}

View File

@ -174,6 +174,11 @@ export class GroupNodeConfig {
node.index = i;
this.processNode(node, seenInputs, seenOutputs);
}
for (const p of this.#convertedToProcess) {
p();
}
this.#convertedToProcess = null;
await app.registerNodeDef("workflow/" + this.name, this.nodeDef);
}
@ -192,7 +197,10 @@ export class GroupNodeConfig {
if (!this.linksFrom[sourceNodeId]) {
this.linksFrom[sourceNodeId] = {};
}
this.linksFrom[sourceNodeId][sourceNodeSlot] = l;
if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) {
this.linksFrom[sourceNodeId][sourceNodeSlot] = [];
}
this.linksFrom[sourceNodeId][sourceNodeSlot].push(l);
if (!this.linksTo[targetNodeId]) {
this.linksTo[targetNodeId] = {};
@ -230,11 +238,11 @@ export class GroupNodeConfig {
// Skip as its not linked
if (!linksFrom) return;
let type = linksFrom["0"][5];
let type = linksFrom["0"][0][5];
if (type === "COMBO") {
// Use the array items
const source = node.outputs[0].widget.name;
const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type;
const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type;
const fromType = globalDefs[fromTypeName];
const input = fromType.input.required[source] ?? fromType.input.optional[source];
type = input[0];
@ -258,10 +266,33 @@ export class GroupNodeConfig {
return null;
}
let config = {};
let rerouteType = "*";
if (linksFrom) {
const [, , id, slot] = linksFrom["0"];
rerouteType = this.nodeData.nodes[id].inputs[slot].type;
for (const [, , id, slot] of linksFrom["0"]) {
const node = this.nodeData.nodes[id];
const input = node.inputs[slot];
if (rerouteType === "*") {
rerouteType = input.type;
}
if (input.widget) {
const targetDef = globalDefs[node.type];
const targetWidget =
targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name];
const widget = [targetWidget[0], config];
const res = mergeIfValid(
{
widget,
},
targetWidget,
false,
null,
widget
);
config = res?.customConfig ?? config;
}
}
} else if (linksTo) {
const [id, slot] = linksTo["0"];
rerouteType = this.nodeData.nodes[id].outputs[slot].type;
@ -282,10 +313,11 @@ export class GroupNodeConfig {
}
}
config.forceInput = true;
return {
input: {
required: {
[rerouteType]: [rerouteType, {}],
[rerouteType]: [rerouteType, config],
},
},
output: [rerouteType],
@ -299,16 +331,17 @@ export class GroupNodeConfig {
getInputConfig(node, inputName, seenInputs, config, extra) {
let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName;
let key = name;
let prefix = "";
// Special handling for primitive to include the title if it is set rather than just "value"
if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) {
prefix = `${node.title ?? node.type} `;
name = `${prefix}${inputName}`;
key = name = `${prefix}${inputName}`;
if (name in seenInputs) {
name = `${prefix}${seenInputs[name]} ${inputName}`;
}
}
seenInputs[name] = (seenInputs[name] ?? 1) + 1;
seenInputs[key] = (seenInputs[key] ?? 1) + 1;
if (inputName === "seed" || inputName === "noise_seed") {
if (!extra) extra = {};
@ -420,10 +453,18 @@ export class GroupNodeConfig {
defaultInput: true,
});
this.nodeDef.input.required[name] = config;
this.newToOldWidgetMap[name] = { node, inputName };
if (!this.oldToNewWidgetMap[node.index]) {
this.oldToNewWidgetMap[node.index] = {};
}
this.oldToNewWidgetMap[node.index][inputName] = name;
inputMap[slots.length + i] = this.inputCount++;
}
}
#convertedToProcess = [];
processNodeInputs(node, seenInputs, inputs) {
const inputMapping = [];
@ -434,7 +475,11 @@ export class GroupNodeConfig {
const linksTo = this.linksTo[node.index] ?? {};
const inputMap = (this.oldToNewInputMap[node.index] = {});
this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs);
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs);
// Converted inputs have to be processed after all other nodes as they'll be at the end of the list
this.#convertedToProcess.push(() =>
this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)
);
return inputMapping;
}
@ -597,11 +642,19 @@ export class GroupNodeHandler {
const output = this.groupData.newToOldOutputMap[link.origin_slot];
let innerNode = this.innerNodes[output.node.index];
let l;
while (innerNode.type === "Reroute") {
while (innerNode?.type === "Reroute") {
l = innerNode.getInputLink(0);
innerNode = innerNode.getInputNode(0);
}
if (!innerNode) {
return null;
}
if (l && GroupNodeHandler.isGroupNode(innerNode)) {
return innerNode.updateLink(l);
}
link.origin_id = innerNode.id;
link.origin_slot = l?.origin_slot ?? output.slot;
return link;
@ -665,6 +718,8 @@ export class GroupNodeHandler {
top = newNode.pos[1];
}
if (!newNode.widgets) continue;
const map = this.groupData.oldToNewWidgetMap[innerNode.index];
if (map) {
const widgets = Object.keys(map);
@ -721,7 +776,7 @@ export class GroupNodeHandler {
}
};
const reconnectOutputs = () => {
const reconnectOutputs = (selectedIds) => {
for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) {
const output = node.outputs[groupOutputId];
if (!output.links) continue;
@ -861,7 +916,7 @@ export class GroupNodeHandler {
if (innerNode.type === "PrimitiveNode") {
innerNode.primitiveValue = newValue;
const primitiveLinked = this.groupData.primitiveToWidget[old.node.index];
for (const linked of primitiveLinked) {
for (const linked of primitiveLinked ?? []) {
const node = this.innerNodes[linked.nodeId];
const widget = node.widgets.find((w) => w.name === linked.inputName);
@ -870,6 +925,18 @@ export class GroupNodeHandler {
}
}
continue;
} else if (innerNode.type === "Reroute") {
const rerouteLinks = this.groupData.linksFrom[old.node.index];
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
const node = this.innerNodes[targetNodeId];
const input = node.inputs[targetSlot];
if (input.widget) {
const widget = node.widgets?.find((w) => w.name === input.widget.name);
if (widget) {
widget.value = newValue;
}
}
}
}
const widget = innerNode.widgets?.find((w) => w.name === old.inputName);
@ -897,33 +964,58 @@ export class GroupNodeHandler {
this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value;
}
}
return true;
}
populateReroute(node, nodeId, map) {
if (node.type !== "Reroute") return;
const link = this.groupData.linksFrom[nodeId]?.[0]?.[0];
if (!link) return;
const [, , targetNodeId, targetNodeSlot] = link;
const targetNode = this.groupData.nodeData.nodes[targetNodeId];
const inputs = targetNode.inputs;
const targetWidget = inputs?.[targetNodeSlot].widget;
if (!targetWidget) return;
const offset = inputs.length - (targetNode.widgets_values?.length ?? 0);
const v = targetNode.widgets_values?.[targetNodeSlot - offset];
if (v == null) return;
const widgetName = Object.values(map)[0];
const widget = this.node.widgets.find(w => w.name === widgetName);
if(widget) {
widget.value = v;
}
}
populateWidgets() {
if (!this.node.widgets) return;
for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) {
const node = this.groupData.nodeData.nodes[nodeId];
if (!node.widgets_values?.length) continue;
const map = this.groupData.oldToNewWidgetMap[nodeId];
const map = this.groupData.oldToNewWidgetMap[nodeId] ?? {};
const widgets = Object.keys(map);
if (!node.widgets_values?.length) {
// special handling for populating values into reroutes
// this allows primitives connect to them to pick up the correct value
this.populateReroute(node, nodeId, map);
continue;
}
let linkedShift = 0;
for (let i = 0; i < widgets.length; i++) {
const oldName = widgets[i];
const newName = map[oldName];
const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName);
const mainWidget = this.node.widgets[widgetIndex];
if (!newName) {
// New name will be null if its a converted widget
this.populatePrimitive(node, nodeId, oldName, i, linkedShift);
if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift) || widgetIndex === -1) {
// Find the inner widget and shift by the number of linked widgets as they will have been removed too
const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName);
linkedShift += innerWidget.linkedWidgets?.length ?? 0;
continue;
linkedShift += innerWidget?.linkedWidgets?.length ?? 0;
}
if (widgetIndex === -1) {
continue;
}

View File

@ -33,6 +33,18 @@ function loadedImageToBlob(image) {
return blob;
}
function loadImage(imagePath) {
return new Promise((resolve, reject) => {
const image = new Image();
image.onload = function() {
resolve(image);
};
image.src = imagePath;
});
}
async function uploadMask(filepath, formData) {
await api.fetchApi('/upload/mask', {
method: 'POST',
@ -50,25 +62,25 @@ async function uploadMask(filepath, formData) {
ClipspaceDialog.invalidatePreview();
}
function prepareRGB(image, backupCanvas, backupCtx) {
function prepare_mask(image, maskCanvas, maskCtx) {
// paste mask data into alpha channel
backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height);
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height);
const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height);
// refine mask image
for (let i = 0; i < backupData.data.length; i += 4) {
if(backupData.data[i+3] == 255)
backupData.data[i+3] = 0;
// invert mask
for (let i = 0; i < maskData.data.length; i += 4) {
if(maskData.data[i+3] == 255)
maskData.data[i+3] = 0;
else
backupData.data[i+3] = 255;
maskData.data[i+3] = 255;
backupData.data[i] = 0;
backupData.data[i+1] = 0;
backupData.data[i+2] = 0;
maskData.data[i] = 0;
maskData.data[i+1] = 0;
maskData.data[i+2] = 0;
}
backupCtx.globalCompositeOperation = 'source-over';
backupCtx.putImageData(backupData, 0, 0);
maskCtx.globalCompositeOperation = 'source-over';
maskCtx.putImageData(maskData, 0, 0);
}
class MaskEditorDialog extends ComfyDialog {
@ -155,10 +167,6 @@ class MaskEditorDialog extends ComfyDialog {
// If it is specified as relative, using it only as a hidden placeholder for padding is recommended
// to prevent anomalies where it exceeds a certain size and goes outside of the window.
var placeholder = document.createElement("div");
placeholder.style.position = "relative";
placeholder.style.height = "50px";
var bottom_panel = document.createElement("div");
bottom_panel.style.position = "absolute";
bottom_panel.style.bottom = "0px";
@ -180,18 +188,16 @@ class MaskEditorDialog extends ComfyDialog {
this.brush = brush;
this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas);
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
this.element.appendChild(bottom_panel);
document.body.appendChild(brush);
var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
self.brush_size = event.target.value;
self.updateBrushPreview(self, null, null);
});
var clearButton = this.createLeftButton("Clear",
() => {
self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
});
var cancelButton = this.createRightButton("Cancel", () => {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
@ -207,40 +213,42 @@ class MaskEditorDialog extends ComfyDialog {
this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas);
this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button
this.element.appendChild(bottom_panel);
bottom_panel.appendChild(clearButton);
bottom_panel.appendChild(this.saveButton);
bottom_panel.appendChild(cancelButton);
bottom_panel.appendChild(brush_size_slider);
bottom_panel.appendChild(this.brush_size_slider);
imgCanvas.style.position = "absolute";
maskCanvas.style.position = "absolute";
imgCanvas.style.position = "relative";
imgCanvas.style.top = "200";
imgCanvas.style.left = "0";
maskCanvas.style.position = "absolute";
maskCanvas.style.top = imgCanvas.style.top;
maskCanvas.style.left = imgCanvas.style.left;
}
show() {
async show() {
this.zoom_ratio = 1.0;
this.pan_x = 0;
this.pan_y = 0;
if(!this.is_layout_created) {
// layout
const imgCanvas = document.createElement('canvas');
const maskCanvas = document.createElement('canvas');
const backupCanvas = document.createElement('canvas');
imgCanvas.id = "imageCanvas";
maskCanvas.id = "maskCanvas";
backupCanvas.id = "backupCanvas";
this.setlayout(imgCanvas, maskCanvas);
// prepare content
this.imgCanvas = imgCanvas;
this.maskCanvas = maskCanvas;
this.backupCanvas = backupCanvas;
this.maskCtx = maskCanvas.getContext('2d');
this.backupCtx = backupCanvas.getContext('2d');
this.maskCtx = maskCanvas.getContext('2d', {willReadFrequently: true });
this.setEventHandler(maskCanvas);
@ -252,6 +260,8 @@ class MaskEditorDialog extends ComfyDialog {
mutations.forEach(function(mutation) {
if (mutation.type === 'attributes' && mutation.attributeName === 'style') {
if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') {
document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp);
self.brush.style.display = "none";
ComfyApp.onClipspaceEditorClosed();
}
@ -264,7 +274,8 @@ class MaskEditorDialog extends ComfyDialog {
observer.observe(this.element, config);
}
this.setImages(this.imgCanvas, this.backupCanvas);
// The keydown event needs to be reconfigured when closing the dialog as it gets removed.
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
if(ComfyApp.clipspace_return_node) {
this.saveButton.innerText = "Save to node";
@ -275,97 +286,157 @@ class MaskEditorDialog extends ComfyDialog {
this.saveButton.disabled = false;
this.element.style.display = "block";
this.element.style.width = "85%";
this.element.style.margin = "0 7.5%";
this.element.style.height = "100vh";
this.element.style.top = "50%";
this.element.style.left = "42%";
this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority.
await this.setImages(this.imgCanvas);
this.is_visible = true;
}
isOpened() {
return this.element.style.display == "block";
}
setImages(imgCanvas, backupCanvas) {
const imgCtx = imgCanvas.getContext('2d');
const backupCtx = backupCanvas.getContext('2d');
invalidateCanvas(orig_image, mask_image) {
this.imgCanvas.width = orig_image.width;
this.imgCanvas.height = orig_image.height;
this.maskCanvas.width = orig_image.width;
this.maskCanvas.height = orig_image.height;
let imgCtx = this.imgCanvas.getContext('2d', {willReadFrequently: true });
let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true });
imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height);
prepare_mask(mask_image, this.maskCanvas, maskCtx);
}
async setImages(imgCanvas) {
let self = this;
const imgCtx = imgCanvas.getContext('2d', {willReadFrequently: true });
const maskCtx = this.maskCtx;
const maskCanvas = this.maskCanvas;
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height);
maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height);
// image load
const orig_image = new Image();
window.addEventListener("resize", () => {
// repositioning
imgCanvas.width = window.innerWidth - 250;
imgCanvas.height = window.innerHeight - 200;
// redraw image
let drawWidth = orig_image.width;
let drawHeight = orig_image.height;
if (orig_image.width > imgCanvas.width) {
drawWidth = imgCanvas.width;
drawHeight = (drawWidth / orig_image.width) * orig_image.height;
}
if (drawHeight > imgCanvas.height) {
drawHeight = imgCanvas.height;
drawWidth = (drawHeight / orig_image.height) * orig_image.width;
}
imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
// update mask
maskCanvas.width = drawWidth;
maskCanvas.height = drawHeight;
maskCanvas.style.top = imgCanvas.offsetTop + "px";
maskCanvas.style.left = imgCanvas.offsetLeft + "px";
backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
});
const filepath = ComfyApp.clipspace.images;
const touched_image = new Image();
touched_image.onload = function() {
backupCanvas.width = touched_image.width;
backupCanvas.height = touched_image.height;
prepareRGB(touched_image, backupCanvas, backupCtx);
};
const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
alpha_url.searchParams.delete('channel');
alpha_url.searchParams.delete('preview');
alpha_url.searchParams.set('channel', 'a');
touched_image.src = alpha_url;
let mask_image = await loadImage(alpha_url);
// original image load
orig_image.onload = function() {
window.dispatchEvent(new Event('resize'));
};
const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
rgb_url.searchParams.delete('channel');
rgb_url.searchParams.set('channel', 'rgb');
orig_image.src = rgb_url;
this.image = orig_image;
this.image = new Image();
this.image.onload = function() {
maskCanvas.width = self.image.width;
maskCanvas.height = self.image.height;
self.invalidateCanvas(self.image, mask_image);
self.initializeCanvasPanZoom();
};
this.image.src = rgb_url;
}
setEventHandler(maskCanvas) {
maskCanvas.addEventListener("contextmenu", (event) => {
event.preventDefault();
});
initializeCanvasPanZoom() {
// set initialize
let drawWidth = this.image.width;
let drawHeight = this.image.height;
let width = this.element.clientWidth;
let height = this.element.clientHeight;
if (this.image.width > width) {
drawWidth = width;
drawHeight = (drawWidth / this.image.width) * this.image.height;
}
if (drawHeight > height) {
drawHeight = height;
drawWidth = (drawHeight / this.image.height) * this.image.width;
}
this.zoom_ratio = drawWidth/this.image.width;
const canvasX = (width - drawWidth) / 2;
const canvasY = (height - drawHeight) / 2;
this.pan_x = canvasX;
this.pan_y = canvasY;
this.invalidatePanZoom();
}
invalidatePanZoom() {
let raw_width = this.image.width * this.zoom_ratio;
let raw_height = this.image.height * this.zoom_ratio;
if(this.pan_x + raw_width < 10) {
this.pan_x = 10 - raw_width;
}
if(this.pan_y + raw_height < 10) {
this.pan_y = 10 - raw_height;
}
let width = `${raw_width}px`;
let height = `${raw_height}px`;
let left = `${this.pan_x}px`;
let top = `${this.pan_y}px`;
this.maskCanvas.style.width = width;
this.maskCanvas.style.height = height;
this.maskCanvas.style.left = left;
this.maskCanvas.style.top = top;
this.imgCanvas.style.width = width;
this.imgCanvas.style.height = height;
this.imgCanvas.style.left = left;
this.imgCanvas.style.top = top;
}
setEventHandler(maskCanvas) {
const self = this;
maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
if(!this.handler_registered) {
maskCanvas.addEventListener("contextmenu", (event) => {
event.preventDefault();
});
this.element.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
this.element.addEventListener('pointermove', (event) => this.pointMoveEvent(self,event));
this.element.addEventListener('touchmove', (event) => this.pointMoveEvent(self,event));
this.element.addEventListener('dragstart', (event) => {
if(event.ctrlKey) {
event.preventDefault();
}
});
maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
this.handler_registered = true;
}
}
brush_size = 10;
@ -378,8 +449,10 @@ class MaskEditorDialog extends ComfyDialog {
const self = MaskEditorDialog.instance;
if (event.key === ']') {
self.brush_size = Math.min(self.brush_size+2, 100);
self.brush_slider_input.value = self.brush_size;
} else if (event.key === '[') {
self.brush_size = Math.max(self.brush_size-2, 1);
self.brush_slider_input.value = self.brush_size;
} else if(event.key === 'Enter') {
self.save();
}
@ -389,6 +462,10 @@ class MaskEditorDialog extends ComfyDialog {
static handlePointerUp(event) {
event.preventDefault();
this.mousedown_x = null;
this.mousedown_y = null;
MaskEditorDialog.instance.drawing_mode = false;
}
@ -398,24 +475,70 @@ class MaskEditorDialog extends ComfyDialog {
var centerX = self.cursorX;
var centerY = self.cursorY;
brush.style.width = self.brush_size * 2 + "px";
brush.style.height = self.brush_size * 2 + "px";
brush.style.left = (centerX - self.brush_size) + "px";
brush.style.top = (centerY - self.brush_size) + "px";
brush.style.width = self.brush_size * 2 * this.zoom_ratio + "px";
brush.style.height = self.brush_size * 2 * this.zoom_ratio + "px";
brush.style.left = (centerX - self.brush_size * this.zoom_ratio) + "px";
brush.style.top = (centerY - self.brush_size * this.zoom_ratio) + "px";
}
handleWheelEvent(self, event) {
if(event.deltaY < 0)
self.brush_size = Math.min(self.brush_size+2, 100);
else
self.brush_size = Math.max(self.brush_size-2, 1);
event.preventDefault();
self.brush_slider_input.value = self.brush_size;
if(event.ctrlKey) {
// zoom canvas
if(event.deltaY < 0) {
this.zoom_ratio = Math.min(10.0, this.zoom_ratio+0.2);
}
else {
this.zoom_ratio = Math.max(0.2, this.zoom_ratio-0.2);
}
this.invalidatePanZoom();
}
else {
// adjust brush size
if(event.deltaY < 0)
this.brush_size = Math.min(this.brush_size+2, 100);
else
this.brush_size = Math.max(this.brush_size-2, 1);
this.brush_slider_input.value = this.brush_size;
this.updateBrushPreview(this);
}
}
pointMoveEvent(self, event) {
this.cursorX = event.pageX;
this.cursorY = event.pageY;
self.updateBrushPreview(self);
if(event.ctrlKey) {
event.preventDefault();
self.pan_move(self, event);
}
}
pan_move(self, event) {
if(event.buttons == 1) {
if(this.mousedown_x) {
let deltaX = this.mousedown_x - event.clientX;
let deltaY = this.mousedown_y - event.clientY;
self.pan_x = this.mousedown_pan_x - deltaX;
self.pan_y = this.mousedown_pan_y - deltaY;
self.invalidatePanZoom();
}
}
}
draw_move(self, event) {
if(event.ctrlKey) {
return;
}
event.preventDefault();
this.cursorX = event.pageX;
@ -439,6 +562,9 @@ class MaskEditorDialog extends ComfyDialog {
y = event.targetTouches[0].clientY - maskRect.top;
}
x /= self.zoom_ratio;
y /= self.zoom_ratio;
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure;
@ -489,8 +615,8 @@ class MaskEditorDialog extends ComfyDialog {
}
else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
const maskRect = self.maskCanvas.getBoundingClientRect();
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
@ -540,6 +666,17 @@ class MaskEditorDialog extends ComfyDialog {
}
handlePointerDown(self, event) {
if(event.ctrlKey) {
if (event.buttons == 1) {
this.mousedown_x = event.clientX;
this.mousedown_y = event.clientY;
this.mousedown_pan_x = this.pan_x;
this.mousedown_pan_y = this.pan_y;
}
return;
}
var brush_size = this.brush_size;
if(event instanceof PointerEvent && event.pointerType == 'pen') {
brush_size *= event.pressure;
@ -551,8 +688,8 @@ class MaskEditorDialog extends ComfyDialog {
event.preventDefault();
const maskRect = self.maskCanvas.getBoundingClientRect();
const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
self.maskCtx.beginPath();
if (event.button == 0) {
@ -570,15 +707,18 @@ class MaskEditorDialog extends ComfyDialog {
}
async save() {
const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
const backupCanvas = document.createElement('canvas');
const backupCtx = backupCanvas.getContext('2d', {willReadFrequently:true});
backupCanvas.width = this.image.width;
backupCanvas.height = this.image.height;
backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
backupCtx.clearRect(0,0, backupCanvas.width, backupCanvas.height);
backupCtx.drawImage(this.maskCanvas,
0, 0, this.maskCanvas.width, this.maskCanvas.height,
0, 0, this.backupCanvas.width, this.backupCanvas.height);
0, 0, backupCanvas.width, backupCanvas.height);
// paste mask data into alpha channel
const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height);
// refine mask image
for (let i = 0; i < backupData.data.length; i += 4) {
@ -615,7 +755,7 @@ class MaskEditorDialog extends ComfyDialog {
ComfyApp.clipspace.widgets[index].value = item;
}
const dataURL = this.backupCanvas.toDataURL();
const dataURL = backupCanvas.toDataURL();
const blob = dataURLToBlob(dataURL);
let original_url = new URL(this.image.src);

View File

@ -1,10 +1,11 @@
import { app } from "../../scripts/app.js";
import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs.js";
// Node that allows you to redirect connections for cleaner graphs
app.registerExtension({
name: "Comfy.RerouteNode",
registerCustomNodes() {
registerCustomNodes(app) {
class RerouteNode {
constructor() {
if (!this.properties) {
@ -16,6 +17,12 @@ app.registerExtension({
this.addInput("", "*");
this.addOutput(this.properties.showOutputText ? "*" : "", "*");
this.onAfterGraphConfigured = function () {
requestAnimationFrame(() => {
this.onConnectionsChange(LiteGraph.INPUT, null, true, null);
});
};
this.onConnectionsChange = function (type, index, connected, link_info) {
this.applyOrientation();
@ -47,6 +54,7 @@ app.registerExtension({
const linkId = currentNode.inputs[0].link;
if (linkId !== null) {
const link = app.graph.links[linkId];
if (!link) return;
const node = app.graph.getNodeById(link.origin_id);
const type = node.constructor.type;
if (type === "Reroute") {
@ -54,8 +62,7 @@ app.registerExtension({
// We've found a circle
currentNode.disconnectInput(link.target_slot);
currentNode = null;
}
else {
} else {
// Move the previous node
currentNode = node;
}
@ -94,8 +101,11 @@ app.registerExtension({
updateNodes.push(node);
} else {
// We've found an output
const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
if (inputType && nodeOutType !== inputType) {
const nodeOutType =
node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type
? node.inputs[link.target_slot].type
: null;
if (inputType && inputType !== "*" && nodeOutType !== inputType) {
// The output doesnt match our input so disconnect it
node.disconnectInput(link.target_slot);
} else {
@ -111,6 +121,9 @@ app.registerExtension({
const displayType = inputType || outputType || "*";
const color = LGraphCanvas.link_type_colors[displayType];
let widgetConfig;
let targetWidget;
let widgetType;
// Update the types of each node
for (const node of updateNodes) {
// If we dont have an input type we are always wildcard but we'll show the output type
@ -125,10 +138,38 @@ app.registerExtension({
const link = app.graph.links[l];
if (link) {
link.color = color;
if (app.configuringGraph) continue;
const targetNode = app.graph.getNodeById(link.target_id);
const targetInput = targetNode.inputs?.[link.target_slot];
if (targetInput?.widget) {
const config = getWidgetConfig(targetInput);
if (!widgetConfig) {
widgetConfig = config[1] ?? {};
widgetType = config[0];
}
if (!targetWidget) {
targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name);
}
const merged = mergeIfValid(targetInput, [config[0], widgetConfig]);
if (merged.customConfig) {
widgetConfig = merged.customConfig;
}
}
}
}
}
for (const node of updateNodes) {
if (widgetConfig && outputType) {
node.inputs[0].widget = { name: "value" };
setWidgetConfig(node.inputs[0], [widgetType ?? displayType, widgetConfig], targetWidget);
} else {
setWidgetConfig(node.inputs[0], null);
}
}
if (inputNode) {
const link = app.graph.links[inputNode.inputs[0].link];
if (link) {
@ -173,8 +214,8 @@ app.registerExtension({
},
{
// naming is inverted with respect to LiteGraphNode.horizontal
// LiteGraphNode.horizontal == true means that
// each slot in the inputs and outputs are layed out horizontally,
// LiteGraphNode.horizontal == true means that
// each slot in the inputs and outputs are layed out horizontally,
// which is the opposite of the visual orientation of the inputs and outputs as a node
content: "Set " + (this.properties.horizontal ? "Horizontal" : "Vertical"),
callback: () => {
@ -187,7 +228,7 @@ app.registerExtension({
applyOrientation() {
this.horizontal = this.properties.horizontal;
if (this.horizontal) {
// we correct the input position, because LiteGraphNode.horizontal
// we correct the input position, because LiteGraphNode.horizontal
// doesn't account for title presence
// which reroute nodes don't have
this.inputs[0].pos = [this.size[0] / 2, 0];

View File

@ -1,5 +1,5 @@
import { app } from "../../scripts/app.js";
import { applyTextReplacements } from "../../scripts/utils.js";
// Use widget values and dates in output filenames
app.registerExtension({
@ -7,84 +7,19 @@ app.registerExtension({
async beforeRegisterNodeDef(nodeType, nodeData, app) {
if (nodeData.name === "SaveImage") {
const onNodeCreated = nodeType.prototype.onNodeCreated;
// Simple date formatter
const parts = {
d: (d) => d.getDate(),
M: (d) => d.getMonth() + 1,
h: (d) => d.getHours(),
m: (d) => d.getMinutes(),
s: (d) => d.getSeconds(),
};
const format =
Object.keys(parts)
.map((k) => k + k + "?")
.join("|") + "|yyy?y?";
function formatDate(text, date) {
return text.replace(new RegExp(format, "g"), function (text) {
if (text === "yy") return (date.getFullYear() + "").substring(2);
if (text === "yyyy") return date.getFullYear();
if (text[0] in parts) {
const p = parts[text[0]](date);
return (p + "").padStart(text.length, "0");
}
return text;
});
}
// When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R
// When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R
nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;
const widget = this.widgets.find((w) => w.name === "filename_prefix");
widget.serializeValue = () => {
return widget.value.replace(/%([^%]+)%/g, function (match, text) {
const split = text.split(".");
if (split.length !== 2) {
// Special handling for dates
if (split[0].startsWith("date:")) {
return formatDate(split[0].substring(5), new Date());
}
if (text !== "width" && text !== "height") {
// Dont warn on standard replacements
console.warn("Invalid replacement pattern", text);
}
return match;
}
// Find node with matching S&R property name
let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]);
// If we cant, see if there is a node with that title
if (!nodes.length) {
nodes = app.graph._nodes.filter((n) => n.title === split[0]);
}
if (!nodes.length) {
console.warn("Unable to find node", split[0]);
return match;
}
if (nodes.length > 1) {
console.warn("Multiple nodes matched", split[0], "using first match");
}
const node = nodes[0];
const widget = node.widgets?.find((w) => w.name === split[1]);
if (!widget) {
console.warn("Unable to find widget", split[1], "on node", split[0], node);
return match;
}
return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
});
return applyTextReplacements(app, widget.value);
};
return r;
};
} else {
// When any other node is created add a property to alias the node
// When any other node is created add a property to alias the node
const onNodeCreated = nodeType.prototype.onNodeCreated;
nodeType.prototype.onNodeCreated = function () {
const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined;

View File

@ -71,24 +71,21 @@ function graphEqual(a, b, root = true) {
}
const undoRedo = async (e) => {
const updateState = async (source, target) => {
const prevState = source.pop();
if (prevState) {
target.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState, false);
activeState = prevState;
}
}
if (e.ctrlKey || e.metaKey) {
if (e.key === "y") {
const prevState = redo.pop();
if (prevState) {
undo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
updateState(redo, undo);
return true;
} else if (e.key === "z") {
const prevState = undo.pop();
if (prevState) {
redo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
updateState(undo, redo);
return true;
}
}

View File

@ -1,10 +1,16 @@
import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js";
import { app } from "../../scripts/app.js";
import { applyTextReplacements } from "../../scripts/utils.js";
const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
const CONFIG = Symbol();
const GET_CONFIG = Symbol();
const TARGET = Symbol(); // Used for reroutes to specify the real target widget
export function getWidgetConfig(slot) {
return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG]();
}
function getConfig(widgetName) {
const { nodeData } = this.constructor;
@ -100,7 +106,6 @@ function getWidgetType(config) {
return { type };
}
function isValidCombo(combo, obj) {
// New input isnt a combo
if (!(obj instanceof Array)) {
@ -121,6 +126,31 @@ function isValidCombo(combo, obj) {
return true;
}
export function setWidgetConfig(slot, config, target) {
if (!slot.widget) return;
if (config) {
slot.widget[GET_CONFIG] = () => config;
slot.widget[TARGET] = target;
} else {
delete slot.widget;
}
if (slot.link) {
const link = app.graph.links[slot.link];
if (link) {
const originNode = app.graph.getNodeById(link.origin_id);
if (originNode.type === "PrimitiveNode") {
if (config) {
originNode.recreateWidget();
} else if(!app.configuringGraph) {
originNode.disconnectOutput(0);
originNode.onLastDisconnect();
}
}
}
}
}
export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) {
if (!config1) {
config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
@ -150,7 +180,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi
const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
for (const k of keys.values()) {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput") {
if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline") {
let v1 = config1[1][k];
let v2 = config2[1]?.[k];
@ -405,11 +435,16 @@ app.registerExtension({
};
},
registerCustomNodes() {
const replacePropertyName = "Run widget replace on values";
class PrimitiveNode {
constructor() {
this.addOutput("connect to widget input", "*");
this.serialize_widgets = true;
this.isVirtualNode = true;
if (!this.properties || !(replacePropertyName in this.properties)) {
this.addProperty(replacePropertyName, false, "boolean");
}
}
applyToGraph(extraLinks = []) {
@ -430,18 +465,29 @@ app.registerExtension({
}
let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks];
let v = this.widgets?.[0].value;
if(v && this.properties[replacePropertyName]) {
v = applyTextReplacements(app, v);
}
// For each output link copy our value over the original widget value
for (const linkInfo of links) {
const node = this.graph.getNodeById(linkInfo.target_id);
const input = node.inputs[linkInfo.target_slot];
const widgetName = input.widget.name;
if (widgetName) {
const widget = node.widgets.find((w) => w.name === widgetName);
if (widget) {
widget.value = this.widgets[0].value;
if (widget.callback) {
widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {});
}
let widget;
if (input.widget[TARGET]) {
widget = input.widget[TARGET];
} else {
const widgetName = input.widget.name;
if (widgetName) {
widget = node.widgets.find((w) => w.name === widgetName);
}
}
if (widget) {
widget.value = v;
if (widget.callback) {
widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {});
}
}
}
@ -494,14 +540,13 @@ app.registerExtension({
this.#mergeWidgetConfig();
if (!links?.length) {
this.#onLastDisconnect();
this.onLastDisconnect();
}
}
}
onConnectOutput(slot, type, input, target_node, target_slot) {
// Fires before the link is made allowing us to reject it if it isn't valid
// No widget, we cant connect
if (!input.widget) {
if (!(input.type in ComfyWidgets)) return false;
@ -519,6 +564,10 @@ app.registerExtension({
#onFirstConnection(recreating) {
// First connection can fire before the graph is ready on initial load so random things can be missing
if (!this.outputs[0].links) {
this.onLastDisconnect();
return;
}
const linkId = this.outputs[0].links[0];
const link = this.graph.links[linkId];
if (!link) return;
@ -546,10 +595,10 @@ app.registerExtension({
this.outputs[0].name = type;
this.outputs[0].widget = widget;
this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating);
this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating, widget[TARGET]);
}
#createWidget(inputData, node, widgetName, recreating) {
#createWidget(inputData, node, widgetName, recreating, targetWidget) {
let type = inputData[0];
if (type instanceof Array) {
@ -563,7 +612,9 @@ app.registerExtension({
widget = this.addWidget(type, "value", null, () => {}, {});
}
if (node?.widgets && widget) {
if (targetWidget) {
widget.value = targetWidget.value;
} else if (node?.widgets && widget) {
const theirWidget = node.widgets.find((w) => w.name === widgetName);
if (theirWidget) {
widget.value = theirWidget.value;
@ -577,11 +628,19 @@ app.registerExtension({
}
addValueControlWidgets(this, widget, control_value, undefined, inputData);
let filter = this.widgets_values?.[2];
if(filter && this.widgets.length === 3) {
if (filter && this.widgets.length === 3) {
this.widgets[2].value = filter;
}
}
// Restore any saved control values
const controlValues = this.controlValues;
if(this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) {
for(let i = 0; i < controlValues.length; i++) {
this.widgets[i + 1].value = controlValues[i];
}
}
// When our value changes, update other widgets to reflect our changes
// e.g. so LoadImage shows correct image
const callback = widget.callback;
@ -610,12 +669,14 @@ app.registerExtension({
}
}
#recreateWidget() {
const values = this.widgets.map((w) => w.value);
recreateWidget() {
const values = this.widgets?.map((w) => w.value);
this.#removeWidgets();
this.#onFirstConnection(true);
for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
return this.widgets[0];
if (values?.length) {
for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i];
}
return this.widgets?.[0];
}
#mergeWidgetConfig() {
@ -631,7 +692,7 @@ app.registerExtension({
if (links?.length < 2 && hasConfig) {
// Copy the widget options from the source
if (links.length) {
this.#recreateWidget();
this.recreateWidget();
}
return;
@ -657,7 +718,7 @@ app.registerExtension({
// Only allow connections where the configs match
const output = this.outputs[0];
const config2 = input.widget[GET_CONFIG]();
return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget);
return !!mergeIfValid.call(this, output, config2, forceUpdate, this.recreateWidget);
}
#removeWidgets() {
@ -668,11 +729,20 @@ app.registerExtension({
w.onRemove();
}
}
// Temporarily store the current values in case the node is being recreated
// e.g. by group node conversion
this.controlValues = [];
this.lastType = this.widgets[0]?.type;
for(let i = 1; i < this.widgets.length; i++) {
this.controlValues.push(this.widgets[i].value);
}
setTimeout(() => { delete this.lastType; delete this.controlValues }, 15);
this.widgets.length = 0;
}
}
#onLastDisconnect() {
onLastDisconnect() {
// We cant remove + re-add the output here as if you drag a link over the same link
// it removes, then re-adds, causing it to break
this.outputs[0].type = "*";

View File

@ -48,7 +48,7 @@
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA",
MAX_NUMBER_OF_NODES: 1000, //avoid infinite loops
MAX_NUMBER_OF_NODES: 10000, //avoid infinite loops
DEFAULT_POSITION: [100, 100], //default node position
VALID_SHAPES: ["default", "box", "round", "card"], //,"circle"
@ -3788,16 +3788,42 @@
/**
* returns the bounding of the object, used for rendering purposes
* bounding is: [topleft_cornerx, topleft_cornery, width, height]
* @method getBounding
* @return {Float32Array[4]} the total size
* @param out {Float32Array[4]?} [optional] a place to store the output, to free garbage
* @param compute_outer {boolean?} [optional] set to true to include the shadow and connection points in the bounding calculation
* @return {Float32Array[4]} the bounding box in format of [topleft_cornerx, topleft_cornery, width, height]
*/
LGraphNode.prototype.getBounding = function(out) {
LGraphNode.prototype.getBounding = function(out, compute_outer) {
out = out || new Float32Array(4);
out[0] = this.pos[0] - 4;
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
const nodePos = this.pos;
const isCollapsed = this.flags.collapsed;
const nodeSize = this.size;
let left_offset = 0;
// 1 offset due to how nodes are rendered
let right_offset = 1 ;
let top_offset = 0;
let bottom_offset = 0;
if (compute_outer) {
// 4 offset for collapsed node connection points
left_offset = 4;
// 6 offset for right shadow and collapsed node connection points
right_offset = 6 + left_offset;
// 4 offset for collapsed nodes top connection points
top_offset = 4;
// 5 offset for bottom shadow and collapsed node connection points
bottom_offset = 5 + top_offset;
}
out[0] = nodePos[0] - left_offset;
out[1] = nodePos[1] - LiteGraph.NODE_TITLE_HEIGHT - top_offset;
out[2] = isCollapsed ?
(this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) + right_offset :
nodeSize[0] + right_offset;
out[3] = isCollapsed ?
LiteGraph.NODE_TITLE_HEIGHT + bottom_offset :
nodeSize[1] + LiteGraph.NODE_TITLE_HEIGHT + bottom_offset;
if (this.onBounding) {
this.onBounding(out);
@ -7674,7 +7700,7 @@ LGraphNode.prototype.executeAction = function(action)
continue;
}
if (!overlapBounding(this.visible_area, n.getBounding(temp))) {
if (!overlapBounding(this.visible_area, n.getBounding(temp, true))) {
continue;
} //out of the visible area
@ -11336,6 +11362,7 @@ LGraphNode.prototype.executeAction = function(action)
name_element.innerText = title;
var value_element = dialog.querySelector(".value");
value_element.value = value;
value_element.select();
var input = value_element;
input.addEventListener("keydown", function(e) {

View File

@ -1559,9 +1559,12 @@ export class ComfyApp {
/**
* Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object
* @param { boolean } clean If the graph state, e.g. images, should be cleared
*/
async loadGraphData(graphData) {
this.clean();
async loadGraphData(graphData, clean = true) {
if (clean !== false) {
this.clean();
}
let reset_invalid_values = false;
if (!graphData) {
@ -1771,15 +1774,26 @@ export class ComfyApp {
if (parent?.updateLink) {
link = parent.updateLink(link);
}
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
if (link) {
inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)];
}
}
}
}
output[String(node.id)] = {
let node_data = {
inputs,
class_type: node.comfyClass,
};
if (this.ui.settings.getSettingValue("Comfy.DevMode")) {
// Ignored by the backend.
node_data["_meta"] = {
title: node.title,
}
}
output[String(node.id)] = node_data;
}
}
@ -2006,12 +2020,8 @@ export class ComfyApp {
async refreshComboInNodes() {
const defs = await api.getNodeDefs();
for(const nodeId in LiteGraph.registered_node_types) {
const node = LiteGraph.registered_node_types[nodeId];
const nodeDef = defs[nodeId];
if(!nodeDef) continue;
node.nodeData = nodeDef;
for (const nodeId in defs) {
this.registerNodeDef(nodeId, defs[nodeId]);
}
for(let nodeNum in this.graph._nodes) {

View File

@ -177,6 +177,7 @@ LGraphCanvas.prototype.computeVisibleNodes = function () {
for (const w of node.widgets) {
if (w.element) {
w.element.hidden = hidden;
w.element.style.display = hidden ? "none" : undefined;
if (hidden) {
w.options.onHide?.(w);
}

67
web/scripts/utils.js Normal file
View File

@ -0,0 +1,67 @@
// Simple date formatter
const parts = {
d: (d) => d.getDate(),
M: (d) => d.getMonth() + 1,
h: (d) => d.getHours(),
m: (d) => d.getMinutes(),
s: (d) => d.getSeconds(),
};
const format =
Object.keys(parts)
.map((k) => k + k + "?")
.join("|") + "|yyy?y?";
function formatDate(text, date) {
return text.replace(new RegExp(format, "g"), function (text) {
if (text === "yy") return (date.getFullYear() + "").substring(2);
if (text === "yyyy") return date.getFullYear();
if (text[0] in parts) {
const p = parts[text[0]](date);
return (p + "").padStart(text.length, "0");
}
return text;
});
}
export function applyTextReplacements(app, value) {
return value.replace(/%([^%]+)%/g, function (match, text) {
const split = text.split(".");
if (split.length !== 2) {
// Special handling for dates
if (split[0].startsWith("date:")) {
return formatDate(split[0].substring(5), new Date());
}
if (text !== "width" && text !== "height") {
// Dont warn on standard replacements
console.warn("Invalid replacement pattern", text);
}
return match;
}
// Find node with matching S&R property name
let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]);
// If we cant, see if there is a node with that title
if (!nodes.length) {
nodes = app.graph._nodes.filter((n) => n.title === split[0]);
}
if (!nodes.length) {
console.warn("Unable to find node", split[0]);
return match;
}
if (nodes.length > 1) {
console.warn("Multiple nodes matched", split[0], "using first match");
}
const node = nodes[0];
const widget = node.widgets?.find((w) => w.name === split[1]);
if (!widget) {
console.warn("Unable to find widget", split[1], "on node", split[0], node);
return match;
}
return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_");
});
}