merging master into filesubflows

This commit is contained in:
Sammy Franklin 2023-10-23 10:09:34 -07:00
commit 1aa10090e6
59 changed files with 8063 additions and 949 deletions

26
.github/workflows/test-ui.yaml vendored Normal file
View File

@ -0,0 +1,26 @@
name: Tests CI
on: [push, pull_request]
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v3
with:
node-version: 18
- uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
- name: Run Tests
run: |
npm ci
npm run test:generate
npm test
working-directory: ./tests-ui

View File

@ -41,13 +41,13 @@ jobs:
- shell: bash
run: |
echo "@echo off
..\python_embeded\python.exe .\update.py ..\ComfyUI\
echo
..\python_embeded\python.exe .\update.py ..\ComfyUI\\
echo -
echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff
echo You should not be running this anyways unless you really have to
echo
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat

1
.gitignore vendored
View File

@ -15,3 +15,4 @@ venv/
/web/extensions/*
!/web/extensions/logging.js.example
!/web/extensions/core/
/tests-ui/data/object_info.json

View File

@ -46,6 +46,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Alt + C | Collapse/uncollapse selected nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes |
@ -69,7 +70,7 @@ Ctrl can also be replaced with Cmd instead for macOS users
There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu118_or_cpu.7z)
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu121_or_cpu.7z)
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
@ -89,6 +90,8 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier.
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

View File

@ -34,8 +34,7 @@ class ControlNet(nn.Module):
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
use_bf16=False,
dtype=torch.float32,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
@ -108,8 +107,7 @@ class ControlNet(nn.Module):
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.dtype = dtype
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample

View File

@ -39,6 +39,7 @@ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORI
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
@ -52,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")

View File

@ -92,8 +92,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
else:
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
else:
return None
clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd)
if len(m) > 0:

View File

@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None):
controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
use_fp16 = comfy.model_management.should_use_fp16()
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None):
return net
if controlnet_config is None:
use_fp16 = comfy.model_management.should_use_fp16()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None):
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)
if use_fp16:
control_model = control_model.half()
control_model = control_model.to(unet_dtype)
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
@ -417,7 +416,7 @@ class T2IAdapter(ControlBase):
if control_prev is not None:
return control_prev
else:
return {}
return None
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:

View File

@ -31,6 +31,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
vae = None
if output_vae:
vae = comfy.sd.VAE(ckpt_path=vae_path)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (unet, clip, vae)

View File

@ -713,8 +713,8 @@ class UniPC:
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
):
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
# t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
steps = len(timesteps) - 1
if method == 'multistep':
@ -769,8 +769,8 @@ class UniPC:
callback(step_index, model_prev_list[-1], x, steps)
else:
raise NotImplementedError()
if denoise_to_zero:
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
# if denoise_to_zero:
# x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
return x
@ -833,21 +833,33 @@ def expand_dims(v, dims):
return v[(...,) + (None,)*(dims - 1)]
class SigmaConvert:
schedule = ""
def marginal_log_mean_coeff(self, sigma):
return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
def marginal_alpha(self, t):
return torch.exp(self.marginal_log_mean_coeff(t))
def marginal_std(self, t):
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
def marginal_lambda(self, t):
"""
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
"""
log_mean_coeff = self.marginal_log_mean_coeff(t)
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
return log_mean_coeff - log_std
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
to_zero = False
timesteps = sigmas.clone()
if sigmas[-1] == 0:
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
to_zero = True
timesteps = sigmas[:]
timesteps[-1] = 0.001
else:
timesteps = sigmas.clone()
alphas_cumprod = model.inner_model.alphas_cumprod
for s in range(timesteps.shape[0]):
timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
ns = SigmaConvert()
if image is not None:
img = image * ns.marginal_alpha(timesteps[0])
@ -859,16 +871,10 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
else:
img = noise
if to_zero:
timesteps[-1] = (1 / len(alphas_cumprod))
device = noise.device
model_type = "noise"
model_fn = model_wrapper(
model.predict_eps_discrete_timestep,
model.predict_eps_sigma,
ns,
model_type=model_type,
guidance_type="uncond",
@ -878,6 +884,5 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
order = min(3, len(timesteps) - 1)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
if not to_zero:
x /= ns.marginal_alpha(timesteps[-1])
x /= ns.marginal_alpha(timesteps[-1])
return x

View File

@ -97,6 +97,10 @@ class DiscreteSchedule(nn.Module):
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
def predict_eps_sigma(self, input, sigma, **kwargs):
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""

View File

@ -20,7 +20,7 @@ class SD15(LatentFormat):
[-0.2829, 0.1762, 0.2721],
[-0.2120, -0.2616, -0.7177]
]
self.taesd_decoder_name = "taesd_decoder.pth"
self.taesd_decoder_name = "taesd_decoder"
class SDXL(LatentFormat):
def __init__(self):
@ -32,4 +32,4 @@ class SDXL(LatentFormat):
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder.pth"
self.taesd_decoder_name = "taesdxl_decoder"

View File

@ -2,67 +2,66 @@ import torch
# import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy.ldm.util import instantiate_from_config
from comfy.ldm.modules.ema import LitEma
# class AutoencoderKL(pl.LightningModule):
class AutoencoderKL(torch.nn.Module):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
self.sample = sample
def get_trainable_parameters(self) -> Any:
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
class AbstractAutoencoder(torch.nn.Module):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""
def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
**kwargs,
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def get_input(self, batch) -> Any:
raise NotImplementedError()
def init_from_ckpt(self, path, ignore_keys=list()):
if path.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(path, device="cpu")
else:
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def on_train_batch_end(self, *args, **kwargs):
# for EMA computation
if self.use_ema:
self.model_ema(self)
@contextmanager
def ema_scope(self, context=None):
@ -70,154 +69,159 @@ class AutoencoderKL(torch.nn.Module):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
logpy.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
logpy.info(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("decode()-method of abstract base class called")
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def configure_optimizers(self) -> Any:
raise NotImplementedError()
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
class AutoencodingEngine(AbstractAutoencoder):
"""
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
(we also restore them explicitly as special cases for legacy reasons).
Regularizations such as KL or VQ are moved to the regularizer class.
"""
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
def __init__(
self,
*args,
encoder_config: Dict,
decoder_config: Dict,
regularizer_config: Dict,
**kwargs,
):
super().__init__(*args, **kwargs)
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
regularizer_config
)
def get_last_layer(self):
return self.decoder.conv_out.weight
return self.decoder.get_last_layer()
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
super().__init__(
encoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
"params": ddconfig,
},
decoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def decode(self, x, *args, **kwargs):
return x
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
def forward(self, x, *args, **kwargs):
return x
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
return dec
class AutoencoderKL(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
)
},
**kwargs,
)

View File

@ -94,253 +94,256 @@ def zero_module(module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
del q, k
# compute attention
b,c,h,w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return out
class CrossAttentionBirchSan(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
def attention_sub_quad(query, key, value, heads, mask=None):
b, _, dim_head = query.shape
dim_head //= heads
self.scale = dim_head ** -0.5
self.heads = heads
scale = dim_head ** -0.5
query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
dtype = query.dtype
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
else:
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
def forward(self, x, context=None, value=None, mask=None):
h = self.heads
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
query = self.to_q(x)
context = default(context, x)
key = self.to_k(context)
if value is not None:
value = self.to_v(value)
else:
value = self.to_v(context)
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
del context, x
kv_chunk_size_min = None
query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
key_t = key.transpose(1,2).unflatten(1, (self.heads, -1)).flatten(end_dim=1)
del key
value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
#not sure at all about the math here
#TODO: tweak this
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 4
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 2
else:
query_chunk_size_x = 1024
kv_chunk_size_min_x = None
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
if kv_chunk_size_x < 1024:
kv_chunk_size_x = None
dtype = query.dtype
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
else:
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key_t.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
kv_chunk_size = k_tokens
else:
query_chunk_size = query_chunk_size_x
kv_chunk_size = kv_chunk_size_x
kv_chunk_size_min = kv_chunk_size_min_x
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
hidden_states = efficient_dot_product_attention(
query,
key,
value,
query_chunk_size=query_chunk_size,
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=False,
upcast_attention=upcast_attention,
)
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
hidden_states = hidden_states.to(dtype)
kv_chunk_size_min = None
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
#not sure at all about the math here
#TODO: tweak this
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 4
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 2
else:
query_chunk_size_x = 1024
kv_chunk_size_min_x = None
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
if kv_chunk_size_x < 1024:
kv_chunk_size_x = None
def attention_split(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
kv_chunk_size = k_tokens
else:
query_chunk_size = query_chunk_size_x
kv_chunk_size = kv_chunk_size_x
kv_chunk_size_min = kv_chunk_size_min_x
h = heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
hidden_states = efficient_dot_product_attention(
query,
key_t,
value,
query_chunk_size=query_chunk_size,
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=self.training,
upcast_attention=upcast_attention,
)
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
hidden_states = hidden_states.to(dtype)
mem_free_total = model_management.get_free_memory(q.device)
hidden_states = hidden_states.unflatten(0, (-1, self.heads)).transpose(1,2).flatten(start_dim=2)
out_proj, dropout = self.to_out
hidden_states = out_proj(hidden_states)
hidden_states = dropout(hidden_states)
return hidden_states
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
class CrossAttentionDoggettx(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
self.scale = dim_head ** -0.5
self.heads = heads
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
def forward(self, x, context=None, value=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
if value is not None:
v_in = self.to_v(value)
del value
else:
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False
cleared_cache = False
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * self.scale
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
first_op_done = True
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:
cleared_cache = True
print("out of memory error, emptying cache and trying again")
continue
steps *= 2
if steps > 64:
raise e
print("out of memory error, increasing steps and trying again", steps)
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False
cleared_cache = False
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
first_op_done = True
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:
cleared_cache = True
print("out of memory error, emptying cache and trying again")
continue
steps *= 2
if steps > 64:
raise e
print("out of memory error, increasing steps and trying again", steps)
else:
raise e
del q, k, v
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
r1 = (
r1.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return r1
return self.to_out(r2)
def attention_xformers(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, -1, heads * dim_head)
)
return out
def attention_pytorch(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
optimized_attention = attention_basic
optimized_attention_masked = attention_basic
if model_management.xformers_enabled():
print("Using xformers cross attention")
optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention")
optimized_attention = attention_pytorch
else:
if args.use_split_cross_attention:
print("Using split optimization for cross attention")
optimized_attention = attention_split
else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
if model_management.pytorch_attention_enabled():
optimized_attention_masked = attention_pytorch
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
@ -348,62 +351,6 @@ class CrossAttention(nn.Module):
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
def forward(self, x, context=None, value=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
if value is not None:
v = self.to_v(value)
del value
else:
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
@ -412,7 +359,6 @@ class MemoryEfficientCrossAttention(nn.Module):
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x)
@ -424,85 +370,12 @@ class MemoryEfficientCrossAttention(nn.Module):
else:
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
class CrossAttentionPytorch(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
if value is not None:
v = self.to_v(value)
del value
if mask is None:
out = optimized_attention(q, k, v, self.heads)
else:
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
if exists(mask):
raise NotImplementedError
out = (
out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
)
out = optimized_attention_masked(q, k, v, self.heads, mask)
return self.to_out(out)
if model_management.xformers_enabled():
print("Using xformers cross attention")
CrossAttention = MemoryEfficientCrossAttention
elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention")
CrossAttention = CrossAttentionPytorch
else:
if args.use_split_cross_attention:
print("Using split optimization for cross attention")
CrossAttention = CrossAttentionDoggettx
else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
CrossAttention = CrossAttentionBirchSan
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,

View File

@ -6,7 +6,6 @@ import numpy as np
from einops import rearrange
from typing import Optional, Any
from ..attention import MemoryEfficientCrossAttention
from comfy import model_management
import comfy.ops
@ -194,6 +193,52 @@ def slice_attention(q, k, v):
return r1
def normal_attention(q, k, v):
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
return h_
def xformers_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
)
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out
def pytorch_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
@ -221,6 +266,16 @@ class AttnBlock(nn.Module):
stride=1,
padding=0)
if model_management.xformers_enabled_vae():
print("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
print("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
print("Using split attention in VAE")
self.optimized_attention = normal_attention
def forward(self, x):
h_ = x
h_ = self.norm(h_)
@ -228,161 +283,15 @@ class AttnBlock(nn.Module):
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
h_ = self.optimized_attention(q, k, v)
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
h_ = self.proj_out(h_)
return x+h_
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.attention_op: Optional[Any] = None
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
)
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = self.proj_out(out)
return x+out
class MemoryEfficientAttnBlockPytorch(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.attention_op: Optional[Any] = None
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = self.proj_out(out)
return x+out
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None):
b, c, h, w = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
out = super().forward(x, context=context, mask=mask)
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
return x + out
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
attn_type = "vanilla-xformers"
if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
attn_type = "vanilla-pytorch"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif attn_type == "vanilla-pytorch":
return MemoryEfficientAttnBlockPytorch(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
raise NotImplementedError()
return AttnBlock(in_channels)
class Model(nn.Module):
@ -632,7 +541,10 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs):
conv_out_op=comfy.ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
**ignorekwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
@ -661,12 +573,12 @@ class Decoder(nn.Module):
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
self.mid.attn_1 = attn_op(block_in)
self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
@ -678,13 +590,13 @@ class Decoder(nn.Module):
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
block.append(resnet_op(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
attn.append(attn_op(block_in))
up = nn.Module()
up.block = block
up.attn = attn
@ -695,13 +607,13 @@ class Decoder(nn.Module):
# end
self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in,
self.conv_out = conv_out_op(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z):
def forward(self, z, **kwargs):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
@ -712,16 +624,16 @@ class Decoder(nn.Module):
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0:
h = self.up[i_level].upsample(h)
@ -731,7 +643,7 @@ class Decoder(nn.Module):
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h = self.conv_out(h, **kwargs)
if self.tanh_out:
h = torch.tanh(h)
return h

View File

@ -296,8 +296,7 @@ class UNetModel(nn.Module):
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
use_bf16=False,
dtype=th.float32,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
@ -370,8 +369,7 @@ class UNetModel(nn.Module):
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.dtype = dtype
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample

View File

@ -26,6 +26,7 @@ class BaseModel(torch.nn.Module):
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
self.inpaint_model = False
print("model_type", model_type.name)
print("adm", self.adm_channels)
@ -71,6 +72,38 @@ class BaseModel(torch.nn.Module):
def encode_adm(self, **kwargs):
return None
def cond_concat(self, **kwargs):
if self.inpaint_model:
concat_keys = ("mask", "masked_image")
cond_concat = []
denoise_mask = kwargs.get("denoise_mask", None)
latent_image = kwargs.get("latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
return blank_image
for ck in concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1].to(device))
elif ck == "masked_image":
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
return cond_concat
return None
def load_model_weights(self, sd, unet_prefix=""):
to_load = {}
keys = list(sd.keys())
@ -112,7 +145,7 @@ class BaseModel(torch.nn.Module):
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
def set_inpaint(self):
self.concat_keys = ("mask", "masked_image")
self.inpaint_model = True
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
adm_inputs = []

View File

@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def detect_unet_config(state_dict, key_prefix, use_fp16):
def detect_unet_config(state_dict, key_prefix, dtype):
state_dict_keys = list(state_dict.keys())
unet_config = {
@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
else:
unet_config["adm_in_channels"] = None
unet_config["use_fp16"] = use_fp16
unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config):
print("no match", unet_config)
return None
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
model_config = model_config_from_unet_config(unet_config)
if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config)
else:
return model_config
def unet_config_from_diffusers_unet(state_dict, use_fp16):
def unet_config_from_diffusers_unet(state_dict, dtype):
match = {}
attention_resolutions = []
@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
return unet_config
return None
def model_config_from_diffusers_unet(state_dict, use_fp16):
unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
def model_config_from_diffusers_unet(state_dict, dtype):
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
if unet_config is not None:
return model_config_from_unet_config(unet_config)
return None

View File

@ -154,14 +154,18 @@ def is_nvidia():
return True
return False
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILABLE = False
VAE_DTYPE = torch.float32
try:
if is_nvidia():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported():
VAE_DTYPE = torch.bfloat16
@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
XFORMERS_IS_AVAILABLE = False
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
@ -336,7 +339,11 @@ def free_memory(memory_required, device, keep_loaded=[]):
if unloaded_model:
soft_empty_cache()
else:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
def load_models_gpu(models, memory_required=0):
global vram_state
@ -354,6 +361,8 @@ def load_models_gpu(models, memory_required=0):
current_loaded_models.insert(0, current_loaded_models.pop(index))
models_already_loaded.append(loaded_model)
else:
if hasattr(x, "model"):
print(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model)
if len(models_to_load) == 0:
@ -363,7 +372,7 @@ def load_models_gpu(models, memory_required=0):
free_memory(extra_mem, d, models_already_loaded)
return
print("loading new")
print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
total_memory_required = {}
for loaded_model in models_to_load:
@ -405,7 +414,6 @@ def load_model_gpu(model):
def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
print(sys.getrefcount(current_loaded_models[i].model))
if sys.getrefcount(current_loaded_models[i].model) <= 2:
to_delete = [i] + to_delete
@ -444,6 +452,13 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev
def unet_dtype(device=None, model_params=0):
if args.bf16_unet:
return torch.bfloat16
if should_use_fp16(device=device, model_params=model_params):
return torch.float16
return torch.float32
def text_encoder_offload_device():
if args.gpu_only:
return get_torch_device()
@ -656,7 +671,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
return False
#FP16 is just broken on these cards
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"]
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
for x in nvidia_16_series:
if x in props.name:
return False

View File

@ -107,6 +107,10 @@ class ModelPatcher:
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
if "unet_wrapper_function" in self.model_options:
wrap_func = self.model_options["unet_wrapper_function"]
if hasattr(wrap_func, "to"):
self.model_options["unet_wrapper_function"] = wrap_func.to(device)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):

View File

@ -98,6 +98,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
samples = samples.cpu()
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
@ -109,5 +110,6 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.cpu()
cleanup_additional_models(models)
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
return samples

View File

@ -14,8 +14,8 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
#The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None):
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(cond, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
if 'timestep_start' in cond[1]:
@ -68,12 +68,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
conditionning = {}
conditionning['c_crossattn'] = cond[0]
if cond_concat_in is not None and len(cond_concat_in) > 0:
cropped = []
for x in cond_concat_in:
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
if 'concat' in cond[1]:
cond_concat_in = cond[1]['concat']
if cond_concat_in is not None and len(cond_concat_in) > 0:
cropped = []
for x in cond_concat_in:
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
if adm_cond is not None:
conditionning['c_adm'] = adm_cond
@ -173,7 +176,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
out['c_adm'] = torch.cat(c_adm)
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0
@ -185,14 +188,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
@ -286,7 +289,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if math.isclose(cond_scale, 1.0):
uncond = None
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
if "sampler_cfg_function" in model_options:
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
return model_options["sampler_cfg_function"](args)
@ -307,8 +310,8 @@ class CFGNoisePredictor(torch.nn.Module):
super().__init__()
self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed)
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
return out
@ -316,11 +319,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None):
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None:
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed)
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
if denoise_mask is not None:
out *= denoise_mask
@ -358,15 +361,6 @@ def sgm_scheduler(model, steps):
sigs += [0.0]
return torch.FloatTensor(sigs)
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
return blank_image
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@ -543,6 +537,20 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
return conds
def encode_cond(model_function, key, conds, device, **kwargs):
for t in range(len(conds)):
x = conds[t]
params = x[1].copy()
params["device"] = device
for k in kwargs:
if k not in params:
params[k] = kwargs[k]
out = model_function(**params)
if out is not None:
x[1] = x[1].copy()
x[1][key] = out
return conds
class Sampler:
def sample(self):
@ -662,31 +670,19 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if model.is_adm():
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if hasattr(model, 'cond_concat'):
positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
cond_concat = None
if hasattr(model, 'concat_keys'): #inpaint
cond_concat = []
for ck in model.concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1])
elif ck == "masked_image":
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
extra_args["cond_concat"] = cond_concat
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))
@ -743,7 +739,7 @@ class KSampler:
sigmas = None
discard_penultimate_sigma = False
if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
steps += 1
discard_penultimate_sigma = True

View File

@ -4,7 +4,7 @@ import math
from comfy import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
import yaml
import comfy.utils
@ -140,21 +140,24 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
def __init__(self, ckpt_path=None, device=None, config=None):
def __init__(self, sd=None, device=None, config=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
if ckpt_path is not None:
sd = comfy.utils.load_torch_file(ckpt_path)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
print("Missing VAE keys", m)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
print("Missing VAE keys", m)
if len(u) > 0:
print("Leftover VAE keys", u)
if device is None:
device = model_management.vae_device()
@ -183,7 +186,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float()
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@ -229,7 +232,7 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float()
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@ -327,7 +330,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if "params" in model_config_params["unet_config"]:
unet_config = model_config_params["unet_config"]["params"]
if "use_fp16" in unet_config:
fp16 = unet_config["use_fp16"]
fp16 = unet_config.pop("use_fp16")
if fp16:
unet_config["dtype"] = torch.float16
noise_aug_config = None
if "noise_aug_config" in model_config_params:
@ -373,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model.load_model_weights(state_dict, "model.diffusion_model.")
if output_vae:
w = WeightsLoader()
vae = VAE(config=vae_config)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, state_dict)
vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
vae = VAE(sd=vae_sd, config=vae_config)
if output_clip:
w = WeightsLoader()
@ -405,12 +408,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
fp16 = model_management.should_use_fp16(model_params=parameters)
unet_dtype = model_management.unet_dtype(model_params=parameters)
class WeightsLoader(torch.nn.Module):
pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@ -418,29 +421,24 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
dtype = torch.float32
if fp16:
dtype = torch.float16
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.")
if output_vae:
vae = VAE()
w = WeightsLoader()
w.first_stage_model = vae.first_stage_model
load_model_weights(w, sd)
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae = VAE(sd=vae_sd)
if output_clip:
w = WeightsLoader()
clip_target = model_config.clip_target()
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
sd = model_config.process_clip_state_dict(sd)
load_model_weights(w, sd)
if clip_target is not None:
clip = CLIP(clip_target, embedding_directory=embedding_directory)
w.cond_stage_model = clip.cond_stage_model
sd = model_config.process_clip_state_dict(sd)
load_model_weights(w, sd)
left_over = sd.keys()
if len(left_over) > 0:
@ -458,15 +456,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet(unet_path): #load unet in diffusers format
sd = comfy.utils.load_torch_file(unet_path)
parameters = comfy.utils.calculate_parameters(sd)
fp16 = model_management.should_use_fp16(model_params=parameters)
unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", fp16)
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
new_sd = sd
else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path)
return None

View File

@ -6,6 +6,8 @@ Tiny AutoEncoder for Stable Diffusion
import torch
import torch.nn as nn
import comfy.utils
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
@ -50,9 +52,9 @@ class TAESD(nn.Module):
self.encoder = Encoder()
self.decoder = Decoder()
if encoder_path is not None:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod
def scale_latents(x):

View File

@ -47,12 +47,17 @@ def state_dict_key_replace(state_dict, keys_to_replace):
state_dict[keys_to_replace[x]] = state_dict.pop(x)
return state_dict
def state_dict_prefix_replace(state_dict, replace_prefix):
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
if filter_keys:
out = {}
else:
out = state_dict
for rp in replace_prefix:
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
for x in replace:
state_dict[x[1]] = state_dict.pop(x[0])
return state_dict
w = state_dict.pop(x[0])
out[x[1]] = w
return out
def transformers_convert(sd, prefix_from, prefix_to, number):
@ -408,6 +413,10 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
output[b:b+1] = out/out_div
return output
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
PROGRESS_BAR_ENABLED = enabled
PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):

View File

@ -158,7 +158,7 @@ class SplitImageWithAlpha:
def split_image_with_alpha(self, image: torch.Tensor):
out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
result = (torch.stack(out_images), torch.stack(out_alphas))
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result
@ -180,7 +180,7 @@ class JoinImageWithAlpha:
batch_size = min(len(image), len(alpha))
out_images = []
alpha = resize_mask(alpha, image.shape[1:])
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))

View File

@ -3,6 +3,7 @@ import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
import latent_preview
import torch
import comfy.utils
class BasicScheduler:
@ -15,7 +16,7 @@ class BasicScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling"
CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas"
@ -35,7 +36,7 @@ class KarrasScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling"
CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas"
@ -53,7 +54,7 @@ class ExponentialScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling"
CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas"
@ -72,7 +73,7 @@ class PolyexponentialScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling"
CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas"
@ -91,7 +92,7 @@ class VPScheduler:
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "_for_testing/custom_sampling"
CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas"
@ -108,7 +109,7 @@ class SplitSigmas:
}
}
RETURN_TYPES = ("SIGMAS","SIGMAS")
CATEGORY = "_for_testing/custom_sampling"
CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sigmas"
@ -125,7 +126,7 @@ class KSamplerSelect:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "_for_testing/custom_sampling"
CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sampler"
@ -144,7 +145,7 @@ class SamplerDPMPP_2M_SDE:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "_for_testing/custom_sampling"
CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sampler"
@ -168,7 +169,7 @@ class SamplerDPMPP_SDE:
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "_for_testing/custom_sampling"
CATEGORY = "sampling/custom_sampling"
FUNCTION = "get_sampler"
@ -201,7 +202,7 @@ class SamplerCustom:
FUNCTION = "sample"
CATEGORY = "_for_testing/custom_sampling"
CATEGORY = "sampling/custom_sampling"
def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
latent = latent_image
@ -219,7 +220,7 @@ class SamplerCustom:
x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
disable_pbar = False
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy()

View File

@ -61,7 +61,53 @@ class FreeU:
m.set_model_output_block_patch(output_block_patch)
return (m, )
class FreeU_V2:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
def output_block_patch(h, hsp, transformer_options):
scale = scale_dict.get(h.shape[1], None)
if scale is not None:
hidden_mean = h.mean(1).unsqueeze(1)
B = hidden_mean.shape[0]
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
if hsp.device not in on_cpu_devices:
try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else:
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
return h, hsp
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
return (m, )
NODE_CLASS_MAPPINGS = {
"FreeU": FreeU,
"FreeU_V2": FreeU_V2,
}

View File

@ -19,6 +19,7 @@ def load_hypernetwork_patch(path, strength):
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
"softsign": torch.nn.Softsign,
"mish": torch.nn.Mish,
}
if activation_func not in valid_activation:
@ -42,7 +43,8 @@ def load_hypernetwork_patch(path, strength):
linears = list(map(lambda a: a[:-len(".weight")], linears))
layers = []
for i in range(len(linears)):
i = 0
while i < len(linears):
lin_name = linears[i]
last_layer = (i == (len(linears) - 1))
penultimate_layer = (i == (len(linears) - 2))
@ -56,10 +58,17 @@ def load_hypernetwork_patch(path, strength):
if (not last_layer) or (activate_output):
layers.append(valid_activation[activation_func]())
if is_layer_norm:
layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
i += 1
ln_name = linears[i]
ln_weight = attn_weights['{}.weight'.format(ln_name)]
ln_bias = attn_weights['{}.bias'.format(ln_name)]
ln = torch.nn.LayerNorm(ln_weight.shape[0])
ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
layers.append(ln)
if use_dropout:
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
layers.append(torch.nn.Dropout(p=0.3))
i += 1
output.append(torch.nn.Sequential(*layers))
out[dim] = torch.nn.ModuleList(output)

View File

@ -0,0 +1,83 @@
#Taken from: https://github.com/tfernd/HyperTile/
import math
from einops import rearrange
import random
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
min_value = min(min_value, value)
# All big divisors of value (inclusive)
divisors = [i for i in range(min_value, value + 1) if value % i == 0]
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
random.seed(counter)
idx = random.randint(0, len(ns) - 1)
return ns[idx]
class HyperTile:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
"scale_depth": ("BOOLEAN", {"default": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
model_channels = model.model.model_config.unet_config["model_channels"]
apply_to = set()
temp = model_channels
for x in range(max_depth + 1):
apply_to.add(temp)
temp *= 2
latent_tile_size = max(32, tile_size) // 8
self.temp = None
self.counter = 1
def hypertile_in(q, k, v, extra_options):
if q.shape[-1] in apply_to:
shape = extra_options["original_shape"]
aspect_ratio = shape[-1] / shape[-2]
hw = q.size(1)
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
self.counter += 1
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
self.counter += 1
if nh * nw > 1:
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
self.temp = (nh, nw, h, w)
return q, k, v
return q, k, v
def hypertile_out(out, extra_options):
if self.temp is not None:
nh, nw, h, w = self.temp
self.temp = None
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
m = model.clone()
m.set_model_attn1_patch(hypertile_in)
m.set_model_attn1_output_patch(hypertile_out)
return (m, )
NODE_CLASS_MAPPINGS = {
"HyperTile": HyperTile,
}

View File

@ -240,8 +240,8 @@ class MaskComposite:
right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
visible_width, visible_height = (right - left, bottom - top,)
source_portion = source[:visible_height, :visible_width]
destination_portion = destination[top:bottom, left:right]
source_portion = source[:, :visible_height, :visible_width]
destination_portion = destination[:, top:bottom, left:right]
if operation == "multiply":
output[:, top:bottom, left:right] = destination_portion * source_portion
@ -282,10 +282,10 @@ class FeatherMask:
def feather(self, mask, left, top, right, bottom):
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
left = min(left, output.shape[1])
right = min(right, output.shape[1])
top = min(top, output.shape[0])
bottom = min(bottom, output.shape[0])
left = min(left, output.shape[-1])
right = min(right, output.shape[-1])
top = min(top, output.shape[-2])
bottom = min(bottom, output.shape[-2])
for x in range(left):
feather_rate = (x + 1.0) / left

View File

@ -1,6 +1,7 @@
import comfy.sd
import comfy.utils
import comfy.model_base
import comfy.model_management
import folder_paths
import json
@ -178,6 +179,95 @@ class CheckpointSave:
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {}
class CLIPSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()])
clip_sd = clip.get_sd()
for prefix in ["clip_l.", "clip_g.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
for x in k:
current_clip_sd[x] = clip_sd.pop(x)
if len(current_clip_sd) == 0:
continue
p = prefix[:-1]
replace_prefix = {}
filename_prefix_ = filename_prefix
if len(p) > 0:
filename_prefix_ = "{}_{}".format(filename_prefix_, p)
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}
class VAESave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
@ -186,4 +276,6 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
}

View File

@ -2,6 +2,7 @@ import os
import sys
import copy
import json
import logging
import threading
import heapq
import traceback
@ -156,7 +157,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
except comfy.model_management.InterruptProcessingException as iex:
print("Processing interrupted")
logging.info("Processing interrupted")
# skip formatting inputs/outputs
error_details = {
@ -177,8 +178,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
print("!!! Exception during processing !!!")
print(traceback.format_exc())
logging.error("!!! Exception during processing !!!")
logging.error(traceback.format_exc())
error_details = {
"node_id": unique_id,
@ -636,11 +637,11 @@ def validate_prompt(prompt):
if valid is True:
good_outputs.add(o)
else:
print(f"Failed to validate prompt for output {o}:")
logging.error(f"Failed to validate prompt for output {o}:")
if len(reasons) > 0:
print("* (prompt):")
logging.error("* (prompt):")
for reason in reasons:
print(f" - {reason['message']}: {reason['details']}")
logging.error(f" - {reason['message']}: {reason['details']}")
errors += [(o, reasons)]
for node_id, result in validated.items():
valid = result[0]
@ -656,11 +657,11 @@ def validate_prompt(prompt):
"dependent_outputs": [],
"class_type": class_type
}
print(f"* {class_type} {node_id}:")
logging.error(f"* {class_type} {node_id}:")
for reason in reasons:
print(f" - {reason['message']}: {reason['details']}")
logging.error(f" - {reason['message']}: {reason['details']}")
node_errors[node_id]["dependent_outputs"].append(o)
print("Output will be ignored")
logging.error("Output will be ignored")
if len(good_outputs) == 0:
errors_list = []

View File

@ -1,5 +1,6 @@
#Rename this to extra_model_paths.yaml and ComfyUI will load it
#config for a1111 ui
#all you have to do is change the base_path to where yours is installed
a111:
@ -19,6 +20,21 @@ a111:
hypernetworks: models/hypernetworks
controlnet: models/ControlNet
#config for comfyui
#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc.
#comfyui:
# base_path: path/to/comfyui/
# checkpoints: models/checkpoints/
# clip: models/clip/
# clip_vision: models/clip_vision/
# configs: models/configs/
# controlnet: models/controlnet/
# embeddings: models/embeddings/
# loras: models/loras/
# upscale_models: models/upscale_models/
# vae: models/vae/
#other_ui:
# base_path: path/to/ui
# checkpoints: models/checkpoints

View File

@ -31,6 +31,7 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
folder_names_and_paths["subflows"] = ([os.path.join(base_path, "subflows")], supported_subflow_extensions)
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
@ -53,6 +54,10 @@ def set_temp_directory(temp_dir):
global temp_directory
temp_directory = temp_dir
def set_input_directory(input_dir):
global input_directory
input_directory = input_dir
def get_output_directory():
global output_directory
return output_directory
@ -155,7 +160,7 @@ def recursive_search(directory, excluded_dir_names=None):
return result, dirs
def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))

View File

@ -56,7 +56,12 @@ def get_previewer(device, latent_format):
# TODO previewer methods
taesd_decoder_path = None
if latent_format.taesd_decoder_name is not None:
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
taesd_decoder_path = next(
(fn for fn in folder_paths.get_filename_list("vae_approx")
if fn.startswith(latent_format.taesd_decoder_name)),
""
)
taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB

10
main.py
View File

@ -175,6 +175,16 @@ if __name__ == "__main__":
print(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
print(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci:
exit(0)

View File

@ -584,7 +584,8 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name)
vae = comfy.sd.VAE(ckpt_path=vae_path)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,)
class ControlNetLoader:
@ -1202,7 +1203,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
noise_mask = latent["noise_mask"]
callback = latent_preview.prepare_callback(model, steps)
disable_pbar = False
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
@ -1660,7 +1661,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"KSampler": "KSampler",
"KSamplerAdvanced": "KSampler (Advanced)",
# Loaders
"CheckpointLoader": "Load Checkpoint (With Config)",
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
"CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE",
"LoraLoader": "Load LoRA",
@ -1797,6 +1798,7 @@ def init_custom_nodes():
"nodes_freelunch.py",
"nodes_custom_sampler.py",
"nodes_subflow.py",
"nodes_hypertile.py",
]
for node_file in extras_files:

View File

@ -47,7 +47,7 @@
" !git pull\n",
"\n",
"!echo -= Install dependencies =-\n",
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
]
},
{

1
tests-ui/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
node_modules

View File

@ -0,0 +1,3 @@
{
"presets": ["@babel/preset-env"]
}

14
tests-ui/globalSetup.js Normal file
View File

@ -0,0 +1,14 @@
module.exports = async function () {
global.ResizeObserver = class ResizeObserver {
observe() {}
unobserve() {}
disconnect() {}
};
const { nop } = require("./utils/nopProxy");
global.enableWebGLCanvas = nop;
HTMLCanvasElement.prototype.getContext = nop;
localStorage["Comfy.Settings.Comfy.Logging.Enabled"] = "false";
};

9
tests-ui/jest.config.js Normal file
View File

@ -0,0 +1,9 @@
/** @type {import('jest').Config} */
const config = {
testEnvironment: "jsdom",
setupFiles: ["./globalSetup.js"],
clearMocks: true,
resetModules: true,
};
module.exports = config;

5566
tests-ui/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

30
tests-ui/package.json Normal file
View File

@ -0,0 +1,30 @@
{
"name": "comfui-tests",
"version": "1.0.0",
"description": "UI tests",
"main": "index.js",
"scripts": {
"test": "jest",
"test:generate": "node setup.js"
},
"repository": {
"type": "git",
"url": "git+https://github.com/comfyanonymous/ComfyUI.git"
},
"keywords": [
"comfyui",
"test"
],
"author": "comfyanonymous",
"license": "GPL-3.0",
"bugs": {
"url": "https://github.com/comfyanonymous/ComfyUI/issues"
},
"homepage": "https://github.com/comfyanonymous/ComfyUI#readme",
"devDependencies": {
"@babel/preset-env": "^7.22.20",
"@types/jest": "^29.5.5",
"jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0"
}
}

87
tests-ui/setup.js Normal file
View File

@ -0,0 +1,87 @@
const { spawn } = require("child_process");
const { resolve } = require("path");
const { existsSync, mkdirSync, writeFileSync } = require("fs");
const http = require("http");
async function setup() {
// Wait up to 30s for it to start
let success = false;
let child;
for (let i = 0; i < 30; i++) {
try {
await new Promise((res, rej) => {
http
.get("http://127.0.0.1:8188/object_info", (resp) => {
let data = "";
resp.on("data", (chunk) => {
data += chunk;
});
resp.on("end", () => {
// Modify the response data to add some checkpoints
const objectInfo = JSON.parse(data);
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
data = JSON.stringify(objectInfo, undefined, "\t");
const outDir = resolve("./data");
if (!existsSync(outDir)) {
mkdirSync(outDir);
}
const outPath = resolve(outDir, "object_info.json");
console.log(`Writing ${Object.keys(objectInfo).length} nodes to ${outPath}`);
writeFileSync(outPath, data, {
encoding: "utf8",
});
res();
});
})
.on("error", rej);
});
success = true;
break;
} catch (error) {
console.log(i + "/30", error);
if (i === 0) {
// Start the server on first iteration if it fails to connect
console.log("Starting ComfyUI server...");
let python = resolve("../../python_embeded/python.exe");
let args;
let cwd;
if (existsSync(python)) {
args = ["-s", "ComfyUI/main.py"];
cwd = "../..";
} else {
python = "python";
args = ["main.py"];
cwd = "..";
}
args.push("--cpu");
console.log(python, ...args);
child = spawn(python, args, { cwd });
child.on("error", (err) => {
console.log(`Server error (${err})`);
i = 30;
});
child.on("exit", (code) => {
if (!success) {
console.log(`Server exited (${code})`);
i = 30;
}
});
}
await new Promise((r) => {
setTimeout(r, 1000);
});
}
}
child?.kill();
if (!success) {
throw new Error("Waiting for server failed...");
}
}
setup();

View File

@ -0,0 +1,319 @@
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils");
const lg = require("../utils/litegraph");
/**
* @typedef { import("../utils/ezgraph") } Ez
* @typedef { ReturnType<Ez["Ez"]["graph"]>["ez"] } EzNodeFactory
*/
/**
* @param { EzNodeFactory } ez
* @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzInput"]> } input
* @param { string } widgetType
* @param { boolean } hasControlWidget
* @returns
*/
async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) {
// Connect to primitive and ensure its still connected after
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(input);
await checkBeforeAndAfterReload(graph, async () => {
primitive = graph.find(primitive);
let { connections } = primitive.outputs[0];
expect(connections).toHaveLength(1);
expect(connections[0].targetNode.id).toBe(input.node.node.id);
// Ensure widget is correct type
const valueWidget = primitive.widgets.value;
expect(valueWidget.widget.type).toBe(widgetType);
// Check if control_after_generate should be added
if (hasControlWidget) {
const controlWidget = primitive.widgets.control_after_generate;
expect(controlWidget.widget.type).toBe("combo");
}
// Ensure we dont have other widgets
expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget);
});
return primitive;
}
describe("widget inputs", () => {
beforeEach(() => {
lg.setup(global);
});
afterEach(() => {
lg.teardown(global);
});
[
{ name: "int", type: "INT", widget: "number", control: true },
{ name: "float", type: "FLOAT", widget: "number", control: true },
{ name: "text", type: "STRING" },
{
name: "customtext",
type: "STRING",
opt: { multiline: true },
},
{ name: "toggle", type: "BOOLEAN" },
{ name: "combo", type: ["a", "b", "c"], control: true },
].forEach((c) => {
test(`widget conversion + primitive works on ${c.name}`, async () => {
const { ez, graph } = await start({
mockNodeDefs: makeNodeDef("TestNode", { [c.name]: [c.type, c.opt ?? {}] }),
});
// Create test node and convert to input
const n = ez.TestNode();
const w = n.widgets[c.name];
w.convertToInput();
expect(w.isConvertedToInput).toBeTruthy();
const input = w.getConvertedInput();
expect(input).toBeTruthy();
// @ts-ignore : input is valid here
await connectPrimitiveAndReload(ez, graph, input, c.widget ?? c.name, c.control);
});
});
test("converted widget works after reload", async () => {
const { ez, graph } = await start();
let n = ez.CheckpointLoaderSimple();
const inputCount = n.inputs.length;
// Convert ckpt name to an input
n.widgets.ckpt_name.convertToInput();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
expect(n.inputs.ckpt_name).toBeTruthy();
expect(n.inputs.length).toEqual(inputCount + 1);
// Convert back to widget and ensure input is removed
n.widgets.ckpt_name.convertToWidget();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
expect(n.inputs.ckpt_name).toBeFalsy();
expect(n.inputs.length).toEqual(inputCount);
// Convert again and reload the graph to ensure it maintains state
n.widgets.ckpt_name.convertToInput();
expect(n.inputs.length).toEqual(inputCount + 1);
const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true);
// Disconnect & reconnect
primitive.outputs[0].connections[0].disconnect();
let { connections } = primitive.outputs[0];
expect(connections).toHaveLength(0);
primitive.outputs[0].connectTo(n.inputs.ckpt_name);
({ connections } = primitive.outputs[0]);
expect(connections).toHaveLength(1);
expect(connections[0].targetNode.id).toBe(n.node.id);
// Convert back to widget and ensure input is removed
n.widgets.ckpt_name.convertToWidget();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
expect(n.inputs.ckpt_name).toBeFalsy();
expect(n.inputs.length).toEqual(inputCount);
});
test("converted widget works on clone", async () => {
const { graph, ez } = await start();
let n = ez.CheckpointLoaderSimple();
// Convert the widget to an input
n.widgets.ckpt_name.convertToInput();
expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
// Clone the node
n.menu["Clone"].call();
expect(graph.nodes).toHaveLength(2);
const clone = graph.nodes[1];
expect(clone.id).not.toEqual(n.id);
// Ensure the clone has an input
expect(clone.widgets.ckpt_name.isConvertedToInput).toBeTruthy();
expect(clone.inputs.ckpt_name).toBeTruthy();
// Ensure primitive connects to both nodes
let primitive = ez.PrimitiveNode();
primitive.outputs[0].connectTo(n.inputs.ckpt_name);
primitive.outputs[0].connectTo(clone.inputs.ckpt_name);
expect(primitive.outputs[0].connections).toHaveLength(2);
// Convert back to widget and ensure input is removed
clone.widgets.ckpt_name.convertToWidget();
expect(clone.widgets.ckpt_name.isConvertedToInput).toBeFalsy();
expect(clone.inputs.ckpt_name).toBeFalsy();
});
test("shows missing node error on custom node with converted input", async () => {
const { graph } = await start();
const dialogShow = jest.spyOn(graph.app.ui.dialog, "show");
await graph.app.loadGraphData({
last_node_id: 3,
last_link_id: 4,
nodes: [
{
id: 1,
type: "TestNode",
pos: [41.87329101561909, 389.7381480823742],
size: { 0: 220, 1: 374 },
flags: {},
order: 1,
mode: 0,
inputs: [{ name: "test", type: "FLOAT", link: 4, widget: { name: "test" }, slot_index: 0 }],
outputs: [],
properties: { "Node name for S&R": "TestNode" },
widgets_values: [1],
},
{
id: 3,
type: "PrimitiveNode",
pos: [-312, 433],
size: { 0: 210, 1: 82 },
flags: {},
order: 0,
mode: 0,
outputs: [{ links: [4], widget: { name: "test" } }],
title: "test",
properties: {},
},
],
links: [[4, 3, 0, 1, 6, "FLOAT"]],
groups: [],
config: {},
extra: {},
version: 0.4,
});
expect(dialogShow).toBeCalledTimes(1);
expect(dialogShow.mock.calls[0][0]).toContain("the following node types were not found");
expect(dialogShow.mock.calls[0][0]).toContain("TestNode");
});
test("defaultInput widgets can be converted back to inputs", async () => {
const { graph, ez } = await start({
mockNodeDefs: makeNodeDef("TestNode", { example: ["INT", { defaultInput: true }] }),
});
// Create test node and ensure it starts as an input
let n = ez.TestNode();
let w = n.widgets.example;
expect(w.isConvertedToInput).toBeTruthy();
let input = w.getConvertedInput();
expect(input).toBeTruthy();
// Ensure it can be converted to
w.convertToWidget();
expect(w.isConvertedToInput).toBeFalsy();
expect(n.inputs.length).toEqual(0);
// and from
w.convertToInput();
expect(w.isConvertedToInput).toBeTruthy();
input = w.getConvertedInput();
// Reload and ensure it still only has 1 converted widget
if (!assertNotNullOrUndefined(input)) return;
await connectPrimitiveAndReload(ez, graph, input, "number", true);
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
w = n.widgets.example;
expect(w.isConvertedToInput).toBeTruthy();
// Convert back to widget and ensure it is still a widget after reload
w.convertToWidget();
await graph.reload();
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
expect(n.widgets[0].isConvertedToInput).toBeFalsy();
expect(n.inputs.length).toEqual(0);
});
test("forceInput widgets can not be converted back to inputs", async () => {
const { graph, ez } = await start({
mockNodeDefs: makeNodeDef("TestNode", { example: ["INT", { forceInput: true }] }),
});
// Create test node and ensure it starts as an input
let n = ez.TestNode();
let w = n.widgets.example;
expect(w.isConvertedToInput).toBeTruthy();
const input = w.getConvertedInput();
expect(input).toBeTruthy();
// Convert to widget should error
expect(() => w.convertToWidget()).toThrow();
// Reload and ensure it still only has 1 converted widget
if (assertNotNullOrUndefined(input)) {
await connectPrimitiveAndReload(ez, graph, input, "number", true);
n = graph.find(n);
expect(n.widgets).toHaveLength(1);
expect(n.widgets.example.isConvertedToInput).toBeTruthy();
}
});
test("primitive can connect to matching combos on converted widgets", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C"], { forceInput: true }] }),
...makeNodeDef("TestNode2", { example: [["A", "B", "C"], { forceInput: true }] }),
},
});
const n1 = ez.TestNode1();
const n2 = ez.TestNode2();
const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]);
p.outputs[0].connectTo(n2.inputs[0]);
expect(p.outputs[0].connections).toHaveLength(2);
const valueWidget = p.widgets.value;
expect(valueWidget.widget.type).toBe("combo");
expect(valueWidget.widget.options.values).toEqual(["A", "B", "C"]);
});
test("primitive can not connect to non matching combos on converted widgets", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", { example: [["A", "B", "C"], { forceInput: true }] }),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }),
},
});
const n1 = ez.TestNode1();
const n2 = ez.TestNode2();
const p = ez.PrimitiveNode();
p.outputs[0].connectTo(n1.inputs[0]);
expect(() => p.outputs[0].connectTo(n2.inputs[0])).toThrow();
expect(p.outputs[0].connections).toHaveLength(1);
});
test("combo output can not connect to non matching combos list input", async () => {
const { ez } = await start({
mockNodeDefs: {
...makeNodeDef("TestNode1", {}, [["A", "B"]]),
...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }),
...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }),
},
});
const n1 = ez.TestNode1();
const n2 = ez.TestNode2();
const n3 = ez.TestNode3();
n1.outputs[0].connectTo(n2.inputs[0]);
expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow();
});
});

417
tests-ui/utils/ezgraph.js Normal file
View File

@ -0,0 +1,417 @@
// @ts-check
/// <reference path="../../web/types/litegraph.d.ts" />
/**
* @typedef { import("../../web/scripts/app")["app"] } app
* @typedef { import("../../web/types/litegraph") } LG
* @typedef { import("../../web/types/litegraph").IWidget } IWidget
* @typedef { import("../../web/types/litegraph").ContextMenuItem } ContextMenuItem
* @typedef { import("../../web/types/litegraph").INodeInputSlot } INodeInputSlot
* @typedef { import("../../web/types/litegraph").INodeOutputSlot } INodeOutputSlot
* @typedef { InstanceType<LG["LGraphNode"]> & { widgets?: Array<IWidget> } } LGNode
* @typedef { (...args: EzOutput[] | [...EzOutput[], Record<string, unknown>]) => EzNode } EzNodeFactory
*/
export class EzConnection {
/** @type { app } */
app;
/** @type { InstanceType<LG["LLink"]> } */
link;
get originNode() {
return new EzNode(this.app, this.app.graph.getNodeById(this.link.origin_id));
}
get originOutput() {
return this.originNode.outputs[this.link.origin_slot];
}
get targetNode() {
return new EzNode(this.app, this.app.graph.getNodeById(this.link.target_id));
}
get targetInput() {
return this.targetNode.inputs[this.link.target_slot];
}
/**
* @param { app } app
* @param { InstanceType<LG["LLink"]> } link
*/
constructor(app, link) {
this.app = app;
this.link = link;
}
disconnect() {
this.targetInput.disconnect();
}
}
export class EzSlot {
/** @type { EzNode } */
node;
/** @type { number } */
index;
/**
* @param { EzNode } node
* @param { number } index
*/
constructor(node, index) {
this.node = node;
this.index = index;
}
}
export class EzInput extends EzSlot {
/** @type { INodeInputSlot } */
input;
/**
* @param { EzNode } node
* @param { number } index
* @param { INodeInputSlot } input
*/
constructor(node, index, input) {
super(node, index);
this.input = input;
}
disconnect() {
this.node.node.disconnectInput(this.index);
}
}
export class EzOutput extends EzSlot {
/** @type { INodeOutputSlot } */
output;
/**
* @param { EzNode } node
* @param { number } index
* @param { INodeOutputSlot } output
*/
constructor(node, index, output) {
super(node, index);
this.output = output;
}
get connections() {
return (this.node.node.outputs?.[this.index]?.links ?? []).map(
(l) => new EzConnection(this.node.app, this.node.app.graph.links[l])
);
}
/**
* @param { EzInput } input
*/
connectTo(input) {
if (!input) throw new Error("Invalid input");
/**
* @type { LG["LLink"] | null }
*/
const link = this.node.node.connect(this.index, input.node.node, input.index);
if (!link) {
const inp = input.input;
const inName = inp.name || inp.label || inp.type;
throw new Error(
`Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${
this.output.name ?? this.output.type
}#${this.index}] failed.`
);
}
return link;
}
}
export class EzNodeMenuItem {
/** @type { EzNode } */
node;
/** @type { number } */
index;
/** @type { ContextMenuItem } */
item;
/**
* @param { EzNode } node
* @param { number } index
* @param { ContextMenuItem } item
*/
constructor(node, index, item) {
this.node = node;
this.index = index;
this.item = item;
}
call(selectNode = true) {
if (!this.item?.callback) throw new Error(`Menu Item ${this.item?.content ?? "[null]"} has no callback.`);
if (selectNode) {
this.node.select();
}
this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node);
}
}
export class EzWidget {
/** @type { EzNode } */
node;
/** @type { number } */
index;
/** @type { IWidget } */
widget;
/**
* @param { EzNode } node
* @param { number } index
* @param { IWidget } widget
*/
constructor(node, index, widget) {
this.node = node;
this.index = index;
this.widget = widget;
}
get value() {
return this.widget.value;
}
set value(v) {
this.widget.value = v;
}
get isConvertedToInput() {
// @ts-ignore : this type is valid for converted widgets
return this.widget.type === "converted-widget";
}
getConvertedInput() {
if (!this.isConvertedToInput) throw new Error(`Widget ${this.widget.name} is not converted to input.`);
return this.node.inputs.find((inp) => inp.input["widget"]?.name === this.widget.name);
}
convertToWidget() {
if (!this.isConvertedToInput)
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already a widget.`);
this.node.menu[`Convert ${this.widget.name} to widget`].call();
}
convertToInput() {
if (this.isConvertedToInput)
throw new Error(`Widget ${this.widget.name} cannot be converted as it is already an input.`);
this.node.menu[`Convert ${this.widget.name} to input`].call();
}
}
export class EzNode {
/** @type { app } */
app;
/** @type { LGNode } */
node;
/**
* @param { app } app
* @param { LGNode } node
*/
constructor(app, node) {
this.app = app;
this.node = node;
}
get id() {
return this.node.id;
}
get inputs() {
return this.#makeLookupArray("inputs", "name", EzInput);
}
get outputs() {
return this.#makeLookupArray("outputs", "name", EzOutput);
}
get widgets() {
return this.#makeLookupArray("widgets", "name", EzWidget);
}
get menu() {
return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem);
}
select() {
this.app.canvas.selectNode(this.node);
}
// /**
// * @template { "inputs" | "outputs" } T
// * @param { T } type
// * @returns { Record<string, type extends "inputs" ? EzInput : EzOutput> & (type extends "inputs" ? EzInput [] : EzOutput[]) }
// */
// #getSlotItems(type) {
// // @ts-ignore : these items are correct
// return (this.node[type] ?? []).reduce((p, s, i) => {
// if (s.name in p) {
// throw new Error(`Unable to store input ${s.name} on array as name conflicts.`);
// }
// // @ts-ignore
// p.push((p[s.name] = new (type === "inputs" ? EzInput : EzOutput)(this, i, s)));
// return p;
// }, Object.assign([], { $: this }));
// }
/**
* @template { { new(node: EzNode, index: number, obj: any): any } } T
* @param { "inputs" | "outputs" | "widgets" | (() => Array<unknown>) } nodeProperty
* @param { string } nameProperty
* @param { T } ctor
* @returns { Record<string, InstanceType<T>> & Array<InstanceType<T>> }
*/
#makeLookupArray(nodeProperty, nameProperty, ctor) {
const items = typeof nodeProperty === "function" ? nodeProperty() : this.node[nodeProperty];
// @ts-ignore
return (items ?? []).reduce((p, s, i) => {
if (!s) return p;
const name = s[nameProperty];
// @ts-ignore
if (!name || name in p) {
throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`);
}
// @ts-ignore
p.push((p[name] = new ctor(this, i, s)));
return p;
}, Object.assign([], { $: this }));
}
}
export class EzGraph {
/** @type { app } */
app;
/**
* @param { app } app
*/
constructor(app) {
this.app = app;
}
get nodes() {
return this.app.graph._nodes.map((n) => new EzNode(this.app, n));
}
clear() {
this.app.graph.clear();
}
arrange() {
this.app.graph.arrange();
}
stringify() {
return JSON.stringify(this.app.graph.serialize(), undefined, "\t");
}
/**
* @param { number | LGNode | EzNode } obj
* @returns { EzNode }
*/
find(obj) {
let match;
let id;
if (typeof obj === "number") {
id = obj;
} else {
id = obj.id;
}
match = this.app.graph.getNodeById(id);
if (!match) {
throw new Error(`Unable to find node with ID ${id}.`);
}
return new EzNode(this.app, match);
}
/**
* @returns { Promise<void> }
*/
reload() {
const graph = JSON.parse(JSON.stringify(this.app.graph.serialize()));
return new Promise((r) => {
this.app.graph.clear();
setTimeout(async () => {
await this.app.loadGraphData(graph);
r();
}, 10);
});
}
}
export const Ez = {
/**
* Quickly build and interact with a ComfyUI graph
* @example
* const { ez, graph } = Ez.graph(app);
* graph.clear();
* const [model, clip, vae] = ez.CheckpointLoaderSimple();
* const [pos] = ez.CLIPTextEncode(clip, { text: "positive" });
* const [neg] = ez.CLIPTextEncode(clip, { text: "negative" });
* const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage());
* const [image] = ez.VAEDecode(latent, vae);
* const saveNode = ez.SaveImage(image).node;
* console.log(saveNode);
* graph.arrange();
* @param { app } app
* @param { LG["LiteGraph"] } LiteGraph
* @param { LG["LGraphCanvas"] } LGraphCanvas
* @param { boolean } clearGraph
* @returns { { graph: EzGraph, ez: Record<string, EzNodeFactory> } }
*/
graph(app, LiteGraph = window["LiteGraph"], LGraphCanvas = window["LGraphCanvas"], clearGraph = true) {
// Always set the active canvas so things work
LGraphCanvas.active_canvas = app.canvas;
if (clearGraph) {
app.graph.clear();
}
// @ts-ignore : this proxy handles utility methods & node creation
const factory = new Proxy(
{},
{
get(_, p) {
if (typeof p !== "string") throw new Error("Invalid node");
const node = LiteGraph.createNode(p);
if (!node) throw new Error(`Unknown node "${p}"`);
app.graph.add(node);
/**
* @param {Parameters<EzNodeFactory>} args
*/
return function (...args) {
const ezNode = new EzNode(app, node);
const inputs = ezNode.inputs;
let slot = 0;
for (const arg of args) {
if (arg instanceof EzOutput) {
arg.connectTo(inputs[slot++]);
} else {
for (const k in arg) {
ezNode.widgets[k].value = arg[k];
}
}
}
return ezNode;
};
},
}
);
return { graph: new EzGraph(app), ez: factory };
},
};

71
tests-ui/utils/index.js Normal file
View File

@ -0,0 +1,71 @@
const { mockApi } = require("./setup");
const { Ez } = require("./ezgraph");
/**
*
* @param { Parameters<mockApi>[0] } config
* @returns
*/
export async function start(config = undefined) {
mockApi(config);
const { app } = require("../../web/scripts/app");
await app.setup();
return Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]);
}
/**
* @param { ReturnType<Ez["graph"]>["graph"] } graph
* @param { (hasReloaded: boolean) => (Promise<void> | void) } cb
*/
export async function checkBeforeAndAfterReload(graph, cb) {
await cb(false);
await graph.reload();
await cb(true);
}
/**
* @param { string } name
* @param { Record<string, string | [string | string[], any]> } input
* @param { (string | string[])[] | Record<string, string | string[]> } output
* @returns { Record<string, import("../../web/types/comfy").ComfyObjectInfo> }
*/
export function makeNodeDef(name, input, output = {}) {
const nodeDef = {
name,
category: "test",
output: [],
output_name: [],
output_is_list: [],
input: {
required: {}
},
};
for(const k in input) {
nodeDef.input.required[k] = typeof input[k] === "string" ? [input[k], {}] : [...input[k]];
}
if(output instanceof Array) {
output = output.reduce((p, c) => {
p[c] = c;
return p;
}, {})
}
for(const k in output) {
nodeDef.output.push(output[k]);
nodeDef.output_name.push(k);
nodeDef.output_is_list.push(false);
}
return { [name]: nodeDef };
}
/**
/**
* @template { any } T
* @param { T } x
* @returns { x is Exclude<T, null | undefined> }
*/
export function assertNotNullOrUndefined(x) {
expect(x).not.toEqual(null);
expect(x).not.toEqual(undefined);
return true;
}

View File

@ -0,0 +1,36 @@
const fs = require("fs");
const path = require("path");
const { nop } = require("../utils/nopProxy");
function forEachKey(cb) {
for (const k of [
"LiteGraph",
"LGraph",
"LLink",
"LGraphNode",
"LGraphGroup",
"DragAndScale",
"LGraphCanvas",
"ContextMenu",
]) {
cb(k);
}
}
export function setup(ctx) {
const lg = fs.readFileSync(path.resolve("../web/lib/litegraph.core.js"), "utf-8");
const globalTemp = {};
(function (console) {
eval(lg);
}).call(globalTemp, nop);
forEachKey((k) => (ctx[k] = globalTemp[k]));
require(path.resolve("../web/lib/litegraph.extensions.js"));
}
export function teardown(ctx) {
forEachKey((k) => delete ctx[k]);
// Clear document after each run
document.getElementsByTagName("html")[0].innerHTML = "";
}

View File

@ -0,0 +1,6 @@
export const nop = new Proxy(function () {}, {
get: () => nop,
set: () => true,
apply: () => nop,
construct: () => nop,
});

45
tests-ui/utils/setup.js Normal file
View File

@ -0,0 +1,45 @@
require("../../web/scripts/api");
const fs = require("fs");
const path = require("path");
function* walkSync(dir) {
const files = fs.readdirSync(dir, { withFileTypes: true });
for (const file of files) {
if (file.isDirectory()) {
yield* walkSync(path.join(dir, file.name));
} else {
yield path.join(dir, file.name);
}
}
}
/**
* @typedef { import("../../web/types/comfy").ComfyObjectInfo } ComfyObjectInfo
*/
/**
* @param { { mockExtensions?: string[], mockNodeDefs?: Record<string, ComfyObjectInfo> } } config
*/
export function mockApi({ mockExtensions, mockNodeDefs } = {}) {
if (!mockExtensions) {
mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core")))
.filter((x) => x.endsWith(".js"))
.map((x) => path.relative(path.resolve("../web"), x));
}
if (!mockNodeDefs) {
mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json")));
}
jest.mock("../../web/scripts/api", () => ({
get api() {
return {
addEventListener: jest.fn(),
getSystemStats: jest.fn(),
getExtensions: jest.fn(() => mockExtensions),
getNodeDefs: jest.fn(() => mockNodeDefs),
init: jest.fn(),
apiURL: jest.fn((x) => "../../web/" + x),
};
},
}));
}

View File

@ -5,6 +5,61 @@ function setNodeMode(node, mode) {
node.graph.change();
}
function addNodesToGroup(group, nodes=[]) {
var x1, y1, x2, y2;
var nx1, ny1, nx2, ny2;
var node;
x1 = y1 = x2 = y2 = -1;
nx1 = ny1 = nx2 = ny2 = -1;
for (var n of [group._nodes, nodes]) {
for (var i in n) {
node = n[i]
nx1 = node.pos[0]
ny1 = node.pos[1]
nx2 = node.pos[0] + node.size[0]
ny2 = node.pos[1] + node.size[1]
if (node.type != "Reroute") {
ny1 -= LiteGraph.NODE_TITLE_HEIGHT;
}
if (node.flags?.collapsed) {
ny2 = ny1 + LiteGraph.NODE_TITLE_HEIGHT;
if (node?._collapsed_width) {
nx2 = nx1 + Math.round(node._collapsed_width);
}
}
if (x1 == -1 || nx1 < x1) {
x1 = nx1;
}
if (y1 == -1 || ny1 < y1) {
y1 = ny1;
}
if (x2 == -1 || nx2 > x2) {
x2 = nx2;
}
if (y2 == -1 || ny2 > y2) {
y2 = ny2;
}
}
}
var padding = 10;
y1 = y1 - Math.round(group.font_size * 1.4);
group.pos = [x1 - padding, y1 - padding];
group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2];
}
app.registerExtension({
name: "Comfy.GroupOptions",
setup() {
@ -14,6 +69,17 @@ app.registerExtension({
const options = orig.apply(this, arguments);
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
if (!group) {
options.push({
content: "Add Group For Selected Nodes",
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
var group = new LiteGraph.LGraphGroup();
addNodesToGroup(group, this.selected_nodes)
app.canvas.graph.add(group);
this.graph.change();
}
});
return options;
}
@ -21,6 +87,15 @@ app.registerExtension({
group.recomputeInsideNodes();
const nodesInGroup = group._nodes;
options.push({
content: "Add Selected Nodes To Group",
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
addNodesToGroup(group, this.selected_nodes)
this.graph.change();
}
});
// No nodes in group, return default options
if (nodesInGroup.length === 0) {
return options;
@ -38,6 +113,23 @@ app.registerExtension({
}
}
options.push({
content: "Fit Group To Nodes",
callback: () => {
addNodesToGroup(group)
this.graph.change();
}
});
options.push({
content: "Select Nodes",
callback: () => {
this.selectNodes(nodesInGroup);
this.graph.change();
this.canvas.focus();
}
});
// Modes
// 0: Always
// 1: On Event

View File

@ -22,6 +22,15 @@ class ManageTemplates extends ComfyDialog {
super();
this.element.classList.add("comfy-manage-templates");
this.templates = this.load();
this.importInput = $el("input", {
type: "file",
accept: ".json",
multiple: true,
style: {display: "none"},
parent: document.body,
onchange: () => this.importAll(),
});
}
createButtons() {
@ -34,6 +43,22 @@ class ManageTemplates extends ComfyDialog {
onclick: () => this.save(),
})
);
btns.unshift(
$el("button", {
type: "button",
textContent: "Export",
onclick: () => this.exportAll(),
})
);
btns.unshift(
$el("button", {
type: "button",
textContent: "Import",
onclick: () => {
this.importInput.click();
},
})
);
return btns;
}
@ -69,6 +94,52 @@ class ManageTemplates extends ComfyDialog {
localStorage.setItem(id, JSON.stringify(this.templates));
}
async importAll() {
for (const file of this.importInput.files) {
if (file.type === "application/json" || file.name.endsWith(".json")) {
const reader = new FileReader();
reader.onload = async () => {
var importFile = JSON.parse(reader.result);
if (importFile && importFile?.templates) {
for (const template of importFile.templates) {
if (template?.name && template?.data) {
this.templates.push(template);
}
}
this.store();
}
};
await reader.readAsText(file);
}
}
this.importInput.value = null;
this.close();
}
exportAll() {
if (this.templates.length == 0) {
alert("No templates to export.");
return;
}
const json = JSON.stringify({templates: this.templates}, null, 2); // convert the data to a JSON string
const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: "node_templates.json",
style: {display: "none"},
parent: document.body,
});
a.click();
setTimeout(function () {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
}
show() {
// Show list of template names + delete button
super.show(
@ -97,19 +168,48 @@ class ManageTemplates extends ComfyDialog {
}),
]
),
$el("button", {
textContent: "Delete",
style: {
fontSize: "12px",
color: "red",
fontWeight: "normal",
},
onclick: (e) => {
nameInput.value = "";
e.target.style.display = "none";
e.target.previousElementSibling.style.display = "none";
},
}),
$el(
"div",
{},
[
$el("button", {
textContent: "Export",
style: {
fontSize: "12px",
fontWeight: "normal",
},
onclick: (e) => {
const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
const blob = new Blob([json], {type: "application/json"});
const url = URL.createObjectURL(blob);
const a = $el("a", {
href: url,
download: (nameInput.value || t.name) + ".json",
style: {display: "none"},
parent: document.body,
});
a.click();
setTimeout(function () {
a.remove();
window.URL.revokeObjectURL(url);
}, 0);
},
}),
$el("button", {
textContent: "Delete",
style: {
fontSize: "12px",
color: "red",
fontWeight: "normal",
},
onclick: (e) => {
nameInput.value = "";
e.target.parentElement.style.display = "none";
e.target.parentElement.previousElementSibling.style.display = "none";
},
}),
]
),
];
})
)
@ -164,19 +264,17 @@ app.registerExtension({
},
}));
if (subItems.length) {
subItems.push(null, {
content: "Manage",
callback: () => manage.show(),
});
subItems.push(null, {
content: "Manage",
callback: () => manage.show(),
});
options.push({
content: "Node Templates",
submenu: {
options: subItems,
},
});
}
options.push({
content: "Node Templates",
submenu: {
options: subItems,
},
});
return options;
};

View File

@ -100,6 +100,27 @@ function getWidgetType(config) {
return { type };
}
function isValidCombo(combo, obj) {
// New input isnt a combo
if (!(obj instanceof Array)) {
console.log(`connection rejected: tried to connect combo to ${obj}`);
return false;
}
// New imput combo has a different size
if (combo.length !== obj.length) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
// New input combo has different elements
if (combo.find((v, i) => obj[i] !== v)) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
return true;
}
app.registerExtension({
name: "Comfy.WidgetInputs",
async beforeRegisterNodeDef(nodeType, nodeData, app) {
@ -200,6 +221,10 @@ app.registerExtension({
for (const input of this.inputs) {
if (input.widget && !input.widget[GET_CONFIG]) {
input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
const w = this.widgets.find((w) => w.name === input.widget.name);
if (w) {
hideWidget(this, w);
}
}
}
}
@ -252,6 +277,28 @@ app.registerExtension({
return r;
};
// Prevent connecting COMBO lists to converted inputs that dont match types
const onConnectInput = nodeType.prototype.onConnectInput;
nodeType.prototype.onConnectInput = function (targetSlot, type, output, originNode, originSlot) {
const v = onConnectInput?.(this, arguments);
// Not a combo, ignore
if (type !== "COMBO") return v;
// Primitive output, allow that to handle
if (originNode.outputs[originSlot].widget) return v;
// Ensure target is also a combo
const targetCombo = this.inputs[targetSlot].widget?.[GET_CONFIG]?.()?.[0];
if (!targetCombo || !(targetCombo instanceof Array)) return v;
// Check they match
const originConfig = originNode.constructor?.nodeData?.output?.[originSlot];
if (!originConfig || !isValidCombo(targetCombo, originConfig)) {
return false;
}
return v;
};
},
registerCustomNodes() {
class PrimitiveNode {
@ -311,7 +358,7 @@ app.registerExtension({
onAfterGraphConfigured() {
if (this.outputs[0].links?.length && !this.widgets?.length) {
this.#onFirstConnection();
if (!this.#onFirstConnection()) return;
// Populate widget values from config data
if (this.widgets) {
@ -382,13 +429,16 @@ app.registerExtension({
widget = input.widget;
}
const { type } = getWidgetType(widget[GET_CONFIG]());
const config = widget[GET_CONFIG]?.();
if (!config) return;
const { type } = getWidgetType(config);
// Update our output to restrict to the widget type
this.outputs[0].type = type;
this.outputs[0].name = type;
this.outputs[0].widget = widget;
this.#createWidget(widget[CONFIG] ?? widget[GET_CONFIG](), theirNode, widget.name, recreating);
this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating);
}
#createWidget(inputData, node, widgetName, recreating) {
@ -413,7 +463,11 @@ app.registerExtension({
}
if (widget.type === "number" || widget.type === "combo") {
addValueControlWidget(this, widget, "fixed");
let control_value = this.widgets_values?.[1];
if (!control_value) {
control_value = "fixed";
}
addValueControlWidget(this, widget, control_value);
}
// When our value changes, update other widgets to reflect our changes
@ -493,21 +547,7 @@ app.registerExtension({
const config2 = input.widget[GET_CONFIG]();
if (config1[0] instanceof Array) {
// New input isnt a combo
if (!(config2[0] instanceof Array)) {
console.log(`connection rejected: tried to connect combo to ${config2[0]}`);
return false;
}
// New imput combo has a different size
if (config1[0].length !== config2[0].length) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
// New input combo has different elements
if (config1[0].find((v, i) => config2[0][i] !== v)) {
console.log(`connection rejected: combo lists dont match`);
return false;
}
if (!isValidCombo(config1[0], config2[0])) return false;
} else if (config1[0] !== config2[0]) {
// Types dont match
console.log(`connection rejected: types dont match`, config1[0], config2[0]);

View File

@ -3799,7 +3799,7 @@
out = out || new Float32Array(4);
out[0] = this.pos[0] - 4;
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
out[2] = this.size[0] + 4;
out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
if (this.onBounding) {

View File

@ -450,8 +450,49 @@ export class ComfyApp {
}
}
function calculateGrid(w, h, n) {
let columns, rows, cellsize;
if (w > h) {
cellsize = h;
columns = Math.ceil(w / cellsize);
rows = Math.ceil(n / columns);
} else {
cellsize = w;
rows = Math.ceil(h / cellsize);
columns = Math.ceil(n / rows);
}
while (columns * rows < n) {
cellsize++;
if (w >= h) {
columns = Math.ceil(w / cellsize);
rows = Math.ceil(n / columns);
} else {
rows = Math.ceil(h / cellsize);
columns = Math.ceil(n / rows);
}
}
const cell_size = Math.min(w/columns, h/rows);
return {cell_size, columns, rows};
}
function is_all_same_aspect_ratio(imgs) {
// assume: imgs.length >= 2
let ratio = imgs[0].naturalWidth/imgs[0].naturalHeight;
for(let i=1; i<imgs.length; i++) {
let this_ratio = imgs[i].naturalWidth/imgs[i].naturalHeight;
if(ratio != this_ratio)
return false;
}
return true;
}
if (this.imgs && this.imgs.length) {
const canvas = graph.list_of_graphcanvas[0];
const canvas = app.graph.list_of_graphcanvas[0];
const mouse = canvas.graph_mouse;
if (!canvas.pointer_is_down && this.pointerDown) {
if (mouse[0] === this.pointerDown.pos[0] && mouse[1] === this.pointerDown.pos[1]) {
@ -460,44 +501,60 @@ export class ComfyApp {
this.pointerDown = null;
}
let w = this.imgs[0].naturalWidth;
let h = this.imgs[0].naturalHeight;
let imageIndex = this.imageIndex;
const numImages = this.imgs.length;
if (numImages === 1 && !imageIndex) {
this.imageIndex = imageIndex = 0;
}
const shiftY = getImageTop(this);
const top = getImageTop(this);
var shiftY = top;
let dw = this.size[0];
let dh = this.size[1];
dh -= shiftY;
if (imageIndex == null) {
let best = 0;
let cellWidth;
let cellHeight;
let cols = 0;
let shiftX = 0;
for (let c = 1; c <= numImages; c++) {
const rows = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / rows;
const scaleX = cW / w;
const scaleY = cH / h;
var cellWidth, cellHeight, shiftX, cell_padding, cols;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
const compact_mode = is_all_same_aspect_ratio(this.imgs);
if(!compact_mode) {
// use rectangle cell style and border line
cell_padding = 2;
const { cell_size, columns, rows } = calculateGrid(dw, dh, numImages);
cols = columns;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
shiftX = c * ((cW - imageW) / 2);
cellWidth = cell_size;
cellHeight = cell_size;
shiftX = (dw-cell_size*cols)/2;
shiftY = (dh-cell_size*rows)/2 + top;
}
else {
cell_padding = 0;
let best = 0;
let w = this.imgs[0].naturalWidth;
let h = this.imgs[0].naturalHeight;
// compact style
for (let c = 1; c <= numImages; c++) {
const rows = Math.ceil(numImages / c);
const cW = dw / c;
const cH = dh / rows;
const scaleX = cW / w;
const scaleY = cH / h;
const scale = Math.min(scaleX, scaleY, 1);
const imageW = w * scale;
const imageH = h * scale;
const area = imageW * imageH * numImages;
if (area > best) {
best = area;
cellWidth = imageW;
cellHeight = imageH;
cols = c;
shiftX = c * ((cW - imageW) / 2);
}
}
}
@ -542,7 +599,14 @@ export class ComfyApp {
let imgWidth = ratio * img.width;
let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2;
ctx.drawImage(img, imgX, imgY, imgWidth, imgHeight);
ctx.drawImage(img, imgX+cell_padding, imgY+cell_padding, imgWidth-cell_padding*2, imgHeight-cell_padding*2);
if(!compact_mode) {
// rectangle cell and border line style
ctx.strokeStyle = "#8F8F8F";
ctx.lineWidth = 1;
ctx.strokeRect(x+cell_padding, y+cell_padding, cellWidth-cell_padding*2, cellHeight-cell_padding*2);
}
ctx.filter = "none";
}
@ -552,6 +616,9 @@ export class ComfyApp {
}
} else {
// Draw individual
let w = this.imgs[imageIndex].naturalWidth;
let h = this.imgs[imageIndex].naturalHeight;
const scaleX = dw / w;
const scaleY = dh / h;
const scale = Math.min(scaleX, scaleY, 1);
@ -594,14 +661,14 @@ export class ComfyApp {
};
if (numImages > 1) {
if (drawButton(x + w - 35, y + h - 35, 30, `${this.imageIndex + 1}/${numImages}`)) {
if (drawButton(dw - 40, dh + top - 40, 30, `${this.imageIndex + 1}/${numImages}`)) {
let i = this.imageIndex + 1 >= numImages ? 0 : this.imageIndex + 1;
if (!this.pointerDown || !this.pointerDown.index === i) {
this.pointerDown = { index: i, pos: [...mouse] };
}
}
if (drawButton(x + w - 35, y + 5, 30, `x`)) {
if (drawButton(dw - 40, top + 10, 30, `x`)) {
if (!this.pointerDown || !this.pointerDown.index === null) {
this.pointerDown = { index: null, pos: [...mouse] };
}
@ -861,6 +928,16 @@ export class ComfyApp {
block_default = true;
}
// Alt + C collapse/uncollapse
if (e.key === 'c' && e.altKey) {
if (this.selected_nodes) {
for (var i in this.selected_nodes) {
this.selected_nodes[i].collapse()
}
}
block_default = true;
}
// Ctrl+C Copy
if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) {
// Trigger onCopy
@ -1339,6 +1416,43 @@ export class ComfyApp {
}
}
loadTemplateData(templateData) {
if (!templateData?.templates) {
return;
}
const old = localStorage.getItem("litegrapheditor_clipboard");
var maxY, nodeBottom, node;
for (const template of templateData.templates) {
if (!template?.data) {
continue;
}
localStorage.setItem("litegrapheditor_clipboard", template.data);
app.canvas.pasteFromClipboard();
// Move mouse position down to paste the next template below
maxY = false;
for (const i in app.canvas.selected_nodes) {
node = app.canvas.selected_nodes[i];
nodeBottom = node.pos[1] + node.size[1];
if (maxY === false || nodeBottom > maxY) {
maxY = nodeBottom;
}
}
app.canvas.graph_mouse[1] = maxY + 50;
}
localStorage.setItem("litegrapheditor_clipboard", old);
}
/**
* Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object
@ -1564,7 +1678,7 @@ export class ComfyApp {
all_inputs = all_inputs.concat(Object.keys(parent.inputs))
for (let parent_input in all_inputs) {
parent_input = all_inputs[parent_input];
if (parent.inputs[parent_input].type === node.inputs[i].type) {
if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
link = parent.getInputLink(parent_input);
if (link) {
parent = parent.getInputNode(parent_input);
@ -1727,7 +1841,12 @@ export class ComfyApp {
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader();
reader.onload = () => {
this.loadGraphData(JSON.parse(reader.result));
var jsonContent = JSON.parse(reader.result);
if (jsonContent?.templates) {
this.loadTemplateData(jsonContent);
} else {
this.loadGraphData(jsonContent);
}
};
reader.readAsText(file);
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {

View File

@ -809,7 +809,8 @@ export class ComfyUI {
if (
this.lastQueueSize != 0 &&
status.exec_info.queue_remaining == 0 &&
document.getElementById("autoQueueCheckbox").checked
document.getElementById("autoQueueCheckbox").checked &&
! app.lastExecutionError
) {
app.queuePrompt(0, this.batchCount);
}

View File

@ -84,6 +84,7 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random
if (targetWidget.value > max)
targetWidget.value = max;
targetWidget.callback(targetWidget.value);
}
}
return valueControl;