Merge branch 'comfyanonymous:master' into feat/widgetFeedback

This commit is contained in:
Dr.Lt.Data 2023-10-18 09:53:15 +09:00 committed by GitHub
commit d7ed61d4cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 812 additions and 788 deletions

View File

@ -46,6 +46,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Ctrl + S | Save workflow | | Ctrl + S | Save workflow |
| Ctrl + O | Load workflow | | Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes | | Ctrl + A | Select all nodes |
| Alt + C | Collapse/uncollapse selected nodes |
| Ctrl + M | Mute/unmute selected nodes | | Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) | | Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes | | Delete/Backspace | Delete selected nodes |
@ -89,6 +90,8 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae Put your VAE in: models/vae
Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier.
### AMD GPUs (Linux only) ### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

View File

@ -34,8 +34,7 @@ class ControlNet(nn.Module):
dims=2, dims=2,
num_classes=None, num_classes=None,
use_checkpoint=False, use_checkpoint=False,
use_fp16=False, dtype=torch.float32,
use_bf16=False,
num_heads=-1, num_heads=-1,
num_head_channels=-1, num_head_channels=-1,
num_heads_upsample=-1, num_heads_upsample=-1,
@ -108,8 +107,7 @@ class ControlNet(nn.Module):
self.conv_resample = conv_resample self.conv_resample = conv_resample
self.num_classes = num_classes self.num_classes = num_classes
self.use_checkpoint = use_checkpoint self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32 self.dtype = dtype
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.num_heads = num_heads self.num_heads = num_heads
self.num_head_channels = num_head_channels self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample self.num_heads_upsample = num_heads_upsample

View File

@ -39,6 +39,7 @@ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORI
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).") parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
@ -52,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpvae_group = parser.add_mutually_exclusive_group() fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")

View File

@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None):
controlnet_config = None controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
use_fp16 = comfy.model_management.should_use_fp16() unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16) controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None):
return net return net
if controlnet_config is None: if controlnet_config is None:
use_fp16 = comfy.model_management.should_use_fp16() unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None):
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected) print(missing, unexpected)
if use_fp16: control_model = control_model.to(unet_dtype)
control_model = control_model.half()
global_average_pooling = False global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0] filename = os.path.splitext(ckpt_path)[0]

View File

@ -31,6 +31,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
vae = None vae = None
if output_vae: if output_vae:
vae = comfy.sd.VAE(ckpt_path=vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (unet, clip, vae) return (unet, clip, vae)

View File

@ -20,7 +20,7 @@ class SD15(LatentFormat):
[-0.2829, 0.1762, 0.2721], [-0.2829, 0.1762, 0.2721],
[-0.2120, -0.2616, -0.7177] [-0.2120, -0.2616, -0.7177]
] ]
self.taesd_decoder_name = "taesd_decoder.pth" self.taesd_decoder_name = "taesd_decoder"
class SDXL(LatentFormat): class SDXL(LatentFormat):
def __init__(self): def __init__(self):
@ -32,4 +32,4 @@ class SDXL(LatentFormat):
[ 0.0568, 0.1687, -0.0755], [ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076] [-0.3112, -0.2359, -0.2076]
] ]
self.taesd_decoder_name = "taesdxl_decoder.pth" self.taesd_decoder_name = "taesdxl_decoder"

View File

@ -2,67 +2,66 @@ import torch
# import pytorch_lightning as pl # import pytorch_lightning as pl
import torch.nn.functional as F import torch.nn.functional as F
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy.ldm.util import instantiate_from_config from comfy.ldm.util import instantiate_from_config
from comfy.ldm.modules.ema import LitEma from comfy.ldm.modules.ema import LitEma
# class AutoencoderKL(pl.LightningModule): class DiagonalGaussianRegularizer(torch.nn.Module):
class AutoencoderKL(torch.nn.Module): def __init__(self, sample: bool = True):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
super().__init__() super().__init__()
self.learn_logvar = learn_logvar self.sample = sample
self.image_key = image_key
self.encoder = Encoder(**ddconfig) def get_trainable_parameters(self) -> Any:
self.decoder = Decoder(**ddconfig) yield from ()
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"] def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) log = dict()
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) posterior = DiagonalGaussianDistribution(z)
self.embed_dim = embed_dim if self.sample:
if colorize_nlabels is not None: z = posterior.sample()
assert type(colorize_nlabels)==int else:
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
class AbstractAutoencoder(torch.nn.Module):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""
def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
**kwargs,
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema: if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay) self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None: def get_input(self, batch) -> Any:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) raise NotImplementedError()
def init_from_ckpt(self, path, ignore_keys=list()): def on_train_batch_end(self, *args, **kwargs):
if path.lower().endswith(".safetensors"): # for EMA computation
import safetensors.torch if self.use_ema:
sd = safetensors.torch.load_file(path, device="cpu") self.model_ema(self)
else:
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager @contextmanager
def ema_scope(self, context=None): def ema_scope(self, context=None):
@ -70,154 +69,159 @@ class AutoencoderKL(torch.nn.Module):
self.model_ema.store(self.parameters()) self.model_ema.store(self.parameters())
self.model_ema.copy_to(self) self.model_ema.copy_to(self)
if context is not None: if context is not None:
print(f"{context}: Switched to EMA weights") logpy.info(f"{context}: Switched to EMA weights")
try: try:
yield None yield None
finally: finally:
if self.use_ema: if self.use_ema:
self.model_ema.restore(self.parameters()) self.model_ema.restore(self.parameters())
if context is not None: if context is not None:
print(f"{context}: Restored training weights") logpy.info(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs): def encode(self, *args, **kwargs) -> torch.Tensor:
if self.use_ema: raise NotImplementedError("encode()-method of abstract base class called")
self.model_ema(self)
def encode(self, x): def decode(self, *args, **kwargs) -> torch.Tensor:
h = self.encoder(x) raise NotImplementedError("decode()-method of abstract base class called")
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z): def instantiate_optimizer_from_config(self, params, lr, cfg):
z = self.post_quant_conv(z) logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
dec = self.decoder(z) return get_obj_from_str(cfg["target"])(
return dec params, lr=lr, **cfg.get("params", dict())
)
def forward(self, input, sample_posterior=True): def configure_optimizers(self) -> Any:
posterior = self.encode(input) raise NotImplementedError()
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx): class AutoencodingEngine(AbstractAutoencoder):
inputs = self.get_input(batch, self.image_key) """
reconstructions, posterior = self(inputs) Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
(we also restore them explicitly as special cases for legacy reasons).
Regularizations such as KL or VQ are moved to the regularizer class.
"""
if optimizer_idx == 0: def __init__(
# train encoder+decoder+logvar self,
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, *args,
last_layer=self.get_last_layer(), split="train") encoder_config: Dict,
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) decoder_config: Dict,
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) regularizer_config: Dict,
return aeloss **kwargs,
):
super().__init__(*args, **kwargs)
if optimizer_idx == 1: self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
# train the discriminator self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, self.regularization: AbstractRegularizer = instantiate_from_config(
last_layer=self.get_last_layer(), split="train") regularizer_config
)
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self): def get_last_layer(self):
return self.decoder.conv_out.weight return self.decoder.get_last_layer()
@torch.no_grad() def encode(
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): self,
log = dict() x: torch.Tensor,
x = self.get_input(batch, self.image_key) return_reg_log: bool = False,
x = x.to(self.device) unregularized: bool = False,
if not only_inputs: ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
xrec, posterior = self(x) z = self.encoder(x)
if x.shape[1] > 3: if unregularized:
# colorize with random projection return z, dict()
assert xrec.shape[1] > 3 z, reg_log = self.regularization(z)
x = self.to_rgb(x) if return_reg_log:
xrec = self.to_rgb(xrec) return z, reg_log
log["samples"] = self.decode(torch.randn_like(posterior.sample())) return z
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def to_rgb(self, x): def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
assert self.image_key == "segmentation" x = self.decoder(z, **kwargs)
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x return x
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs): class AutoencodingEngineLegacy(AutoencodingEngine):
return x def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
super().__init__(
encoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
"params": ddconfig,
},
decoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def decode(self, x, *args, **kwargs): def get_autoencoder_params(self) -> list:
return x params = super().get_autoencoder_params()
return params
def quantize(self, x, *args, **kwargs): def encode(
if self.vq_interface: self, x: torch.Tensor, return_reg_log: bool = False
return x, None, [None, None, None] ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
return x if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
def forward(self, x, *args, **kwargs): z, reg_log = self.regularization(z)
return x if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
return dec
class AutoencoderKL(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
)
},
**kwargs,
)

View File

@ -94,253 +94,222 @@ def zero_module(module):
def Normalize(in_channels, dtype=None, device=None): def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None):
h = heads
scale = (q.shape[-1] // heads) ** -0.5
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
class SpatialSelfAttention(nn.Module): # force cast to fp32 to avoid overflowing
def __init__(self, in_channels): if _ATTN_PRECISION =="fp32":
super().__init__() with torch.autocast(enabled=False, device_type = 'cuda'):
self.in_channels = in_channels q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale
self.norm = Normalize(in_channels) del q, k
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x): if exists(mask):
h_ = x mask = rearrange(mask, 'b ... -> b (...)')
h_ = self.norm(h_) max_neg_value = -torch.finfo(sim.dtype).max
q = self.q(h_) mask = repeat(mask, 'b j -> (b h) () j', h=h)
k = self.k(h_) sim.masked_fill_(~mask, max_neg_value)
v = self.v(h_)
# compute attention # attention, what we cannot get enough of
b,c,h,w = q.shape sim = sim.softmax(dim=-1)
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5)) out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
w_ = torch.nn.functional.softmax(w_, dim=2) out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return out
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
class CrossAttentionBirchSan(nn.Module): def attention_sub_quad(query, key, value, heads, mask=None):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): scale = (query.shape[-1] // heads) ** -0.5
super().__init__() query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
inner_dim = dim_head * heads key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1)
context_dim = default(context_dim, query_dim) del key
value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
self.scale = dim_head ** -0.5 dtype = query.dtype
self.heads = heads upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
else:
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key_t.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential( chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
def forward(self, x, context=None, value=None, mask=None): kv_chunk_size_min = None
h = self.heads
query = self.to_q(x) #not sure at all about the math here
context = default(context, x) #TODO: tweak this
key = self.to_k(context) if mem_free_total > 8192 * 1024 * 1024 * 1.3:
if value is not None: query_chunk_size_x = 1024 * 4
value = self.to_v(value) elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
else: query_chunk_size_x = 1024 * 2
value = self.to_v(context) else:
query_chunk_size_x = 1024
kv_chunk_size_min_x = None
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
if kv_chunk_size_x < 1024:
kv_chunk_size_x = None
del context, x if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
kv_chunk_size = k_tokens
else:
query_chunk_size = query_chunk_size_x
kv_chunk_size = kv_chunk_size_x
kv_chunk_size_min = kv_chunk_size_min_x
query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) hidden_states = efficient_dot_product_attention(
key_t = key.transpose(1,2).unflatten(1, (self.heads, -1)).flatten(end_dim=1) query,
del key key_t,
value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) value,
query_chunk_size=query_chunk_size,
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=False,
upcast_attention=upcast_attention,
)
dtype = query.dtype hidden_states = hidden_states.to(dtype)
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
else:
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key_t.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True) hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD def attention_split(q, k, v, heads, mask=None):
scale = (q.shape[-1] // heads) ** -0.5
h = heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
kv_chunk_size_min = None r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
#not sure at all about the math here mem_free_total = model_management.get_free_memory(q.device)
#TODO: tweak this
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 4
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 2
else:
query_chunk_size_x = 1024
kv_chunk_size_min_x = None
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
if kv_chunk_size_x < 1024:
kv_chunk_size_x = None
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: gb = 1024 ** 3
# the big matmul fits into our memory limit; do everything in 1 chunk, tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
# i.e. send it down the unchunked fast-path modifier = 3 if q.element_size() == 2 else 2.5
query_chunk_size = q_tokens mem_required = tensor_size * modifier
kv_chunk_size = k_tokens steps = 1
else:
query_chunk_size = query_chunk_size_x
kv_chunk_size = kv_chunk_size_x
kv_chunk_size_min = kv_chunk_size_min_x
hidden_states = efficient_dot_product_attention(
query,
key_t,
value,
query_chunk_size=query_chunk_size,
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=self.training,
upcast_attention=upcast_attention,
)
hidden_states = hidden_states.to(dtype)
hidden_states = hidden_states.unflatten(0, (-1, self.heads)).transpose(1,2).flatten(start_dim=2)
out_proj, dropout = self.to_out
hidden_states = out_proj(hidden_states)
hidden_states = dropout(hidden_states)
return hidden_states
class CrossAttentionDoggettx(nn.Module): if mem_required > mem_free_total:
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
super().__init__() # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
inner_dim = dim_head * heads # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5 if steps > 64:
self.heads = heads max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) first_op_done = False
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) cleared_cache = False
while True:
self.to_out = nn.Sequential( try:
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
nn.Dropout(dropout) for i in range(0, q.shape[1], slice_size):
) end = i + slice_size
if _ATTN_PRECISION =="fp32":
def forward(self, x, context=None, value=None, mask=None): with torch.autocast(enabled=False, device_type = 'cuda'):
h = self.heads s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
if value is not None:
v_in = self.to_v(value)
del value
else:
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
mem_free_total = model_management.get_free_memory(q.device)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False
cleared_cache = False
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * self.scale
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
first_op_done = True
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:
cleared_cache = True
print("out of memory error, emptying cache and trying again")
continue
steps *= 2
if steps > 64:
raise e
print("out of memory error, increasing steps and trying again", steps)
else: else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
first_op_done = True
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:
cleared_cache = True
print("out of memory error, emptying cache and trying again")
continue
steps *= 2
if steps > 64:
raise e raise e
print("out of memory error, increasing steps and trying again", steps)
else:
raise e
del q, k, v del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1 del r1
return r2
return self.to_out(r2) def attention_xformers(q, k, v, heads, mask=None):
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], heads, -1)
.permute(0, 2, 1, 3)
.reshape(b * heads, t.shape[1], -1)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, heads, out.shape[1], -1)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], -1)
)
return out
def attention_pytorch(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
optimized_attention = attention_basic
optimized_attention_masked = attention_basic
if model_management.xformers_enabled():
print("Using xformers cross attention")
optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention")
optimized_attention = attention_pytorch
else:
if args.use_split_cross_attention:
print("Using split optimization for cross attention")
optimized_attention = attention_split
else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
if model_management.pytorch_attention_enabled():
optimized_attention_masked = attention_pytorch
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
@ -348,62 +317,6 @@ class CrossAttention(nn.Module):
inner_dim = dim_head * heads inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim) context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
def forward(self, x, context=None, value=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
if value is not None:
v = self.to_v(value)
del value
else:
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads self.heads = heads
self.dim_head = dim_head self.dim_head = dim_head
@ -412,7 +325,6 @@ class MemoryEfficientCrossAttention(nn.Module):
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None): def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x) q = self.to_q(x)
@ -424,85 +336,12 @@ class MemoryEfficientCrossAttention(nn.Module):
else: else:
v = self.to_v(context) v = self.to_v(context)
b, _, _ = q.shape if mask is None:
q, k, v = map( out = optimized_attention(q, k, v, self.heads)
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)
class CrossAttentionPytorch(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
if value is not None:
v = self.to_v(value)
del value
else: else:
v = self.to_v(context) out = optimized_attention_masked(q, k, v, self.heads, mask)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
if exists(mask):
raise NotImplementedError
out = (
out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
)
return self.to_out(out) return self.to_out(out)
if model_management.xformers_enabled():
print("Using xformers cross attention")
CrossAttention = MemoryEfficientCrossAttention
elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention")
CrossAttention = CrossAttentionPytorch
else:
if args.use_split_cross_attention:
print("Using split optimization for cross attention")
CrossAttention = CrossAttentionDoggettx
else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
CrossAttention = CrossAttentionBirchSan
class BasicTransformerBlock(nn.Module): class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,

View File

@ -6,7 +6,6 @@ import numpy as np
from einops import rearrange from einops import rearrange
from typing import Optional, Any from typing import Optional, Any
from ..attention import MemoryEfficientCrossAttention
from comfy import model_management from comfy import model_management
import comfy.ops import comfy.ops
@ -194,6 +193,52 @@ def slice_attention(q, k, v):
return r1 return r1
def normal_attention(q, k, v):
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
return h_
def xformers_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
)
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out
def pytorch_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out
class AttnBlock(nn.Module): class AttnBlock(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super().__init__() super().__init__()
@ -221,6 +266,16 @@ class AttnBlock(nn.Module):
stride=1, stride=1,
padding=0) padding=0)
if model_management.xformers_enabled_vae():
print("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
print("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
print("Using split attention in VAE")
self.optimized_attention = normal_attention
def forward(self, x): def forward(self, x):
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)
@ -228,161 +283,15 @@ class AttnBlock(nn.Module):
k = self.k(h_) k = self.k(h_)
v = self.v(h_) v = self.v(h_)
# compute attention h_ = self.optimized_attention(q, k, v)
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
h_ = self.proj_out(h_) h_ = self.proj_out(h_)
return x+h_ return x+h_
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.attention_op: Optional[Any] = None
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
)
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = self.proj_out(out)
return x+out
class MemoryEfficientAttnBlockPytorch(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.attention_op: Optional[Any] = None
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = self.proj_out(out)
return x+out
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None):
b, c, h, w = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
out = super().forward(x, context=context, mask=mask)
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
return x + out
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' return AttnBlock(in_channels)
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
attn_type = "vanilla-xformers"
if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
attn_type = "vanilla-pytorch"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif attn_type == "vanilla-pytorch":
return MemoryEfficientAttnBlockPytorch(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
raise NotImplementedError()
class Model(nn.Module): class Model(nn.Module):
@ -632,7 +541,10 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs): conv_out_op=comfy.ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
**ignorekwargs):
super().__init__() super().__init__()
if use_linear_attn: attn_type = "linear" if use_linear_attn: attn_type = "linear"
self.ch = ch self.ch = ch
@ -661,12 +573,12 @@ class Decoder(nn.Module):
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.attn_1 = attn_op(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout)
@ -678,13 +590,13 @@ class Decoder(nn.Module):
attn = nn.ModuleList() attn = nn.ModuleList()
block_out = ch*ch_mult[i_level] block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1): for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in, block.append(resnet_op(in_channels=block_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout)) dropout=dropout))
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type)) attn.append(attn_op(block_in))
up = nn.Module() up = nn.Module()
up.block = block up.block = block
up.attn = attn up.attn = attn
@ -695,13 +607,13 @@ class Decoder(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in, self.conv_out = conv_out_op(block_in,
out_ch, out_ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
def forward(self, z): def forward(self, z, **kwargs):
#assert z.shape[1:] == self.z_shape[1:] #assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape self.last_z_shape = z.shape
@ -712,16 +624,16 @@ class Decoder(nn.Module):
h = self.conv_in(z) h = self.conv_in(z)
# middle # middle
h = self.mid.block_1(h, temb) h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h) h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb) h = self.mid.block_2(h, temb, **kwargs)
# upsampling # upsampling
for i_level in reversed(range(self.num_resolutions)): for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1): for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb) h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0: if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h) h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0: if i_level != 0:
h = self.up[i_level].upsample(h) h = self.up[i_level].upsample(h)
@ -731,7 +643,7 @@ class Decoder(nn.Module):
h = self.norm_out(h) h = self.norm_out(h)
h = nonlinearity(h) h = nonlinearity(h)
h = self.conv_out(h) h = self.conv_out(h, **kwargs)
if self.tanh_out: if self.tanh_out:
h = torch.tanh(h) h = torch.tanh(h)
return h return h

View File

@ -296,8 +296,7 @@ class UNetModel(nn.Module):
dims=2, dims=2,
num_classes=None, num_classes=None,
use_checkpoint=False, use_checkpoint=False,
use_fp16=False, dtype=th.float32,
use_bf16=False,
num_heads=-1, num_heads=-1,
num_head_channels=-1, num_head_channels=-1,
num_heads_upsample=-1, num_heads_upsample=-1,
@ -370,8 +369,7 @@ class UNetModel(nn.Module):
self.conv_resample = conv_resample self.conv_resample = conv_resample
self.num_classes = num_classes self.num_classes = num_classes
self.use_checkpoint = use_checkpoint self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32 self.dtype = dtype
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.num_heads = num_heads self.num_heads = num_heads
self.num_head_channels = num_head_channels self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample self.num_heads_upsample = num_heads_upsample

View File

@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1 count += 1
return count return count
def detect_unet_config(state_dict, key_prefix, use_fp16): def detect_unet_config(state_dict, key_prefix, dtype):
state_dict_keys = list(state_dict.keys()) state_dict_keys = list(state_dict.keys())
unet_config = { unet_config = {
@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
else: else:
unet_config["adm_in_channels"] = None unet_config["adm_in_channels"] = None
unet_config["use_fp16"] = use_fp16 unet_config["dtype"] = dtype
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config):
print("no match", unet_config) print("no match", unet_config)
return None return None
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False): def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16) unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
model_config = model_config_from_unet_config(unet_config) model_config = model_config_from_unet_config(unet_config)
if model_config is None and use_base_if_no_match: if model_config is None and use_base_if_no_match:
return comfy.supported_models_base.BASE(unet_config) return comfy.supported_models_base.BASE(unet_config)
else: else:
return model_config return model_config
def unet_config_from_diffusers_unet(state_dict, use_fp16): def unet_config_from_diffusers_unet(state_dict, dtype):
match = {} match = {}
attention_resolutions = [] attention_resolutions = []
@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1] match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384, 'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64} 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8} 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4], 'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4], 'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1} 'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
return unet_config return unet_config
return None return None
def model_config_from_diffusers_unet(state_dict, use_fp16): def model_config_from_diffusers_unet(state_dict, dtype):
unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16) unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
if unet_config is not None: if unet_config is not None:
return model_config_from_unet_config(unet_config) return model_config_from_unet_config(unet_config)
return None return None
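
detect_unet_config(), model_config_from_unet() and the diffusers variants now take the target torch dtype instead of a use_fp16 boolean and store it in unet_config["dtype"]. A small sketch of the new call shape, assuming a state dict sd loaded elsewhere:

import torch
import comfy.model_detection as model_detection

model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", torch.float16)
if model_config is None:
    raise RuntimeError("could not detect model type")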


@ -154,14 +154,18 @@ def is_nvidia():
return True return True
return False return False
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILABLE = False
VAE_DTYPE = torch.float32 VAE_DTYPE = torch.float32
try: try:
if is_nvidia(): if is_nvidia():
torch_version = torch.version.__version__ torch_version = torch.version.__version__
if int(torch_version[0]) >= 2: if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():
VAE_DTYPE = torch.bfloat16 VAE_DTYPE = torch.bfloat16
@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True)
XFORMERS_IS_AVAILABLE = False
if args.lowvram: if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM set_vram_to = VRAMState.LOW_VRAM
@ -354,6 +357,8 @@ def load_models_gpu(models, memory_required=0):
current_loaded_models.insert(0, current_loaded_models.pop(index)) current_loaded_models.insert(0, current_loaded_models.pop(index))
models_already_loaded.append(loaded_model) models_already_loaded.append(loaded_model)
else: else:
if hasattr(x, "model"):
print(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model) models_to_load.append(loaded_model)
if len(models_to_load) == 0: if len(models_to_load) == 0:
@ -363,7 +368,7 @@ def load_models_gpu(models, memory_required=0):
free_memory(extra_mem, d, models_already_loaded) free_memory(extra_mem, d, models_already_loaded)
return return
print("loading new") print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
total_memory_required = {} total_memory_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:
@ -405,7 +410,6 @@ def load_model_gpu(model):
def cleanup_models(): def cleanup_models():
to_delete = [] to_delete = []
for i in range(len(current_loaded_models)): for i in range(len(current_loaded_models)):
print(sys.getrefcount(current_loaded_models[i].model))
if sys.getrefcount(current_loaded_models[i].model) <= 2: if sys.getrefcount(current_loaded_models[i].model) <= 2:
to_delete = [i] + to_delete to_delete = [i] + to_delete
@ -444,6 +448,13 @@ def unet_inital_load_device(parameters, dtype):
else: else:
return cpu_dev return cpu_dev
def unet_dtype(device=None, model_params=0):
if args.bf16_unet:
return torch.bfloat16
if should_use_fp16(device=device, model_params=model_params):
return torch.float16
return torch.float32
def text_encoder_offload_device(): def text_encoder_offload_device():
if args.gpu_only: if args.gpu_only:
return get_torch_device() return get_torch_device()
@ -656,7 +667,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
return False return False
#FP16 is just broken on these cards #FP16 is just broken on these cards
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"] nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
for x in nvidia_16_series: for x in nvidia_16_series:
if x in props.name: if x in props.name:
return False return False
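
The new unet_dtype() helper centralizes the precision decision: --bf16-unet forces bfloat16, otherwise fp16 is used when should_use_fp16() allows it, else fp32. A hedged usage sketch (the parameter count is illustrative and a working ComfyUI environment is assumed):

import comfy.model_management as model_management

parameters = 2_600_000_000  # roughly SDXL-sized, for illustration only
dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.unet_inital_load_device(parameters, dtype)
print(dtype, load_device)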


@ -107,6 +107,10 @@ class ModelPatcher:
for k in patch_list: for k in patch_list:
if hasattr(patch_list[k], "to"): if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device) patch_list[k] = patch_list[k].to(device)
if "unet_wrapper_function" in self.model_options:
wrap_func = self.model_options["unet_wrapper_function"]
if hasattr(wrap_func, "to"):
self.model_options["unet_wrapper_function"] = wrap_func.to(device)
def model_dtype(self): def model_dtype(self):
if hasattr(self.model, "get_dtype"): if hasattr(self.model, "get_dtype"):
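
ModelPatcher now also moves a registered "unet_wrapper_function" to the load device when it exposes .to(), which matters when the wrapper is an nn.Module carrying its own weights. A minimal sketch under that assumption; the forward argument layout is assumed for illustration and not taken from this diff:

import torch

class PassthroughWrapper(torch.nn.Module):
    def forward(self, apply_model, args):
        # assumed calling convention: delegate straight to the original UNet apply function
        return apply_model(args["input"], args["timestep"], **args["c"])

patched = model.clone()  # `model` is assumed to be an existing ModelPatcher
patched.model_options["unet_wrapper_function"] = PassthroughWrapper()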


@ -4,7 +4,7 @@ import math
from comfy import model_management from comfy import model_management
from .ldm.util import instantiate_from_config from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
import yaml import yaml
import comfy.utils import comfy.utils
@ -140,21 +140,24 @@ class CLIP:
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
class VAE: class VAE:
def __init__(self, ckpt_path=None, device=None, config=None): def __init__(self, sd=None, device=None, config=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
if config is None: if config is None:
#default SD1.x/SD2.x VAE parameters #default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else: else:
self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval() self.first_stage_model = self.first_stage_model.eval()
if ckpt_path is not None:
sd = comfy.utils.load_torch_file(ckpt_path) m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if len(m) > 0:
sd = diffusers_convert.convert_vae_state_dict(sd) print("Missing VAE keys", m)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0: if len(u) > 0:
print("Missing VAE keys", m) print("Leftover VAE keys", u)
if device is None: if device is None:
device = model_management.vae_device() device = model_management.vae_device()
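
VAE() now takes an already-loaded state dict instead of a checkpoint path, converting diffusers-format keys on the fly and reporting missing/leftover keys. A short sketch of loading a standalone VAE with the new signature (the file name is a placeholder); the VAELoader node further down does exactly this:

import comfy.utils
import comfy.sd

sd = comfy.utils.load_torch_file("models/vae/my_vae.safetensors")  # placeholder path
vae = comfy.sd.VAE(sd=sd)
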
@ -183,7 +186,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float() encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@ -229,7 +232,7 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float() samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@ -327,7 +330,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if "params" in model_config_params["unet_config"]: if "params" in model_config_params["unet_config"]:
unet_config = model_config_params["unet_config"]["params"] unet_config = model_config_params["unet_config"]["params"]
if "use_fp16" in unet_config: if "use_fp16" in unet_config:
fp16 = unet_config["use_fp16"] fp16 = unet_config.pop("use_fp16")
if fp16:
unet_config["dtype"] = torch.float16
noise_aug_config = None noise_aug_config = None
if "noise_aug_config" in model_config_params: if "noise_aug_config" in model_config_params:
@ -373,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model.load_model_weights(state_dict, "model.diffusion_model.") model.load_model_weights(state_dict, "model.diffusion_model.")
if output_vae: if output_vae:
w = WeightsLoader() vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
vae = VAE(config=vae_config) vae = VAE(sd=vae_sd, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, state_dict)
if output_clip: if output_clip:
w = WeightsLoader() w = WeightsLoader()
@ -405,12 +408,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip_target = None clip_target = None
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
fp16 = model_management.should_use_fp16(model_params=parameters) unet_dtype = model_management.unet_dtype(model_params=parameters)
class WeightsLoader(torch.nn.Module): class WeightsLoader(torch.nn.Module):
pass pass
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16) model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
if model_config is None: if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@ -418,21 +421,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clipvision: if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True) clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
dtype = torch.float32
if fp16:
dtype = torch.float16
if output_model: if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, dtype) inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device() offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device) model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
vae = VAE() vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
w = WeightsLoader() vae = VAE(sd=vae_sd)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, sd)
if output_clip: if output_clip:
w = WeightsLoader() w = WeightsLoader()
@ -458,15 +455,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet(unet_path): #load unet in diffusers format def load_unet(unet_path): #load unet in diffusers format
sd = comfy.utils.load_torch_file(unet_path) sd = comfy.utils.load_torch_file(unet_path)
parameters = comfy.utils.calculate_parameters(sd) parameters = comfy.utils.calculate_parameters(sd)
fp16 = model_management.should_use_fp16(model_params=parameters) unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", fp16) model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None: if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
new_sd = sd new_sd = sd
else: #diffusers else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
if model_config is None: if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path) print("ERROR UNSUPPORTED UNET", unet_path)
return None return None


@ -6,6 +6,8 @@ Tiny AutoEncoder for Stable Diffusion
import torch import torch
import torch.nn as nn import torch.nn as nn
import comfy.utils
def conv(n_in, n_out, **kwargs): def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
@ -50,9 +52,9 @@ class TAESD(nn.Module):
self.encoder = Encoder() self.encoder = Encoder()
self.decoder = Decoder() self.decoder = Decoder()
if encoder_path is not None: if encoder_path is not None:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None: if decoder_path is not None:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod @staticmethod
def scale_latents(x): def scale_latents(x):


@ -47,12 +47,17 @@ def state_dict_key_replace(state_dict, keys_to_replace):
state_dict[keys_to_replace[x]] = state_dict.pop(x) state_dict[keys_to_replace[x]] = state_dict.pop(x)
return state_dict return state_dict
def state_dict_prefix_replace(state_dict, replace_prefix): def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
if filter_keys:
out = {}
else:
out = state_dict
for rp in replace_prefix: for rp in replace_prefix:
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
for x in replace: for x in replace:
state_dict[x[1]] = state_dict.pop(x[0]) w = state_dict.pop(x[0])
return state_dict out[x[1]] = w
return out
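
With filter_keys=True, state_dict_prefix_replace() returns only the renamed entries instead of mutating and returning the full state dict, which is what the new first_stage_model extraction in sd.py relies on. A tiny illustration with a toy dict:

import comfy.utils

sd = {"first_stage_model.encoder.w": 1, "model.diffusion_model.x": 2}
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
# vae_sd == {"encoder.w": 1}; the diffusion key is left out instead of being carried along
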
def transformers_convert(sd, prefix_from, prefix_to, number): def transformers_convert(sd, prefix_from, prefix_to, number):
@ -408,6 +413,10 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
output[b:b+1] = out/out_div output[b:b+1] = out/out_div
return output return output
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
PROGRESS_BAR_ENABLED = enabled
PROGRESS_BAR_HOOK = None PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function): def set_progress_bar_global_hook(function):
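
PROGRESS_BAR_ENABLED is a module-level switch that the sampler nodes below now consult when building disable_pbar, so an embedding script can silence sampling progress bars globally. A minimal sketch:

import comfy.utils

comfy.utils.set_progress_bar_enabled(False)  # e.g. for headless or batch runs
assert comfy.utils.PROGRESS_BAR_ENABLED is False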


@ -3,6 +3,7 @@ import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.k_diffusion import sampling as k_diffusion_sampling
import latent_preview import latent_preview
import torch import torch
import comfy.utils
class BasicScheduler: class BasicScheduler:
@ -219,7 +220,7 @@ class SamplerCustom:
x0_output = {} x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
disable_pbar = False disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy() out = latent.copy()


@ -19,6 +19,7 @@ def load_hypernetwork_patch(path, strength):
"tanh": torch.nn.Tanh, "tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid, "sigmoid": torch.nn.Sigmoid,
"softsign": torch.nn.Softsign, "softsign": torch.nn.Softsign,
"mish": torch.nn.Mish,
} }
if activation_func not in valid_activation: if activation_func not in valid_activation:
@ -42,7 +43,8 @@ def load_hypernetwork_patch(path, strength):
linears = list(map(lambda a: a[:-len(".weight")], linears)) linears = list(map(lambda a: a[:-len(".weight")], linears))
layers = [] layers = []
for i in range(len(linears)): i = 0
while i < len(linears):
lin_name = linears[i] lin_name = linears[i]
last_layer = (i == (len(linears) - 1)) last_layer = (i == (len(linears) - 1))
penultimate_layer = (i == (len(linears) - 2)) penultimate_layer = (i == (len(linears) - 2))
@ -56,10 +58,17 @@ def load_hypernetwork_patch(path, strength):
if (not last_layer) or (activate_output): if (not last_layer) or (activate_output):
layers.append(valid_activation[activation_func]()) layers.append(valid_activation[activation_func]())
if is_layer_norm: if is_layer_norm:
layers.append(torch.nn.LayerNorm(lin_weight.shape[0])) i += 1
ln_name = linears[i]
ln_weight = attn_weights['{}.weight'.format(ln_name)]
ln_bias = attn_weights['{}.bias'.format(ln_name)]
ln = torch.nn.LayerNorm(ln_weight.shape[0])
ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
layers.append(ln)
if use_dropout: if use_dropout:
if (not last_layer) and (not penultimate_layer or last_layer_dropout): if (not last_layer) and (not penultimate_layer or last_layer_dropout):
layers.append(torch.nn.Dropout(p=0.3)) layers.append(torch.nn.Dropout(p=0.3))
i += 1
output.append(torch.nn.Sequential(*layers)) output.append(torch.nn.Sequential(*layers))
out[dim] = torch.nn.ModuleList(output) out[dim] = torch.nn.ModuleList(output)


@ -240,8 +240,8 @@ class MaskComposite:
right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2])) right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
visible_width, visible_height = (right - left, bottom - top,) visible_width, visible_height = (right - left, bottom - top,)
source_portion = source[:visible_height, :visible_width] source_portion = source[:, :visible_height, :visible_width]
destination_portion = destination[top:bottom, left:right] destination_portion = destination[:, top:bottom, left:right]
if operation == "multiply": if operation == "multiply":
output[:, top:bottom, left:right] = destination_portion * source_portion output[:, top:bottom, left:right] = destination_portion * source_portion
@ -282,10 +282,10 @@ class FeatherMask:
def feather(self, mask, left, top, right, bottom): def feather(self, mask, left, top, right, bottom):
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
left = min(left, output.shape[1]) left = min(left, output.shape[-1])
right = min(right, output.shape[1]) right = min(right, output.shape[-1])
top = min(top, output.shape[0]) top = min(top, output.shape[-2])
bottom = min(bottom, output.shape[0]) bottom = min(bottom, output.shape[-2])
for x in range(left): for x in range(left):
feather_rate = (x + 1.0) / left feather_rate = (x + 1.0) / left
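
Both fixes treat masks as batched (batch, height, width) tensors: MaskComposite slices on the trailing two dimensions, and FeatherMask clamps left/right against the width (last dim) and top/bottom against the height. A short sketch of the indexing this assumes:

import torch

mask = torch.ones((1, 256, 384))      # (batch, height, width)
left = min(400, mask.shape[-1])       # clamped to the width, 384
top = min(300, mask.shape[-2])        # clamped to the height, 256
region = mask[:, :top, :left]         # keep the batch dim, then slice height and width
print(region.shape)                   # torch.Size([1, 256, 384])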


@ -1,6 +1,7 @@
import comfy.sd import comfy.sd
import comfy.utils import comfy.utils
import comfy.model_base import comfy.model_base
import comfy.model_management
import folder_paths import folder_paths
import json import json
@ -178,6 +179,95 @@ class CheckpointSave:
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {} return {}
class CLIPSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()])
clip_sd = clip.get_sd()
for prefix in ["clip_l.", "clip_g.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
for x in k:
current_clip_sd[x] = clip_sd.pop(x)
if len(current_clip_sd) == 0:
continue
p = prefix[:-1]
replace_prefix = {}
filename_prefix_ = filename_prefix
if len(p) > 0:
filename_prefix_ = "{}_{}".format(filename_prefix_, p)
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""
full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}
class VAESave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)
metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple, "ModelMergeSimple": ModelMergeSimple,
@ -186,4 +276,6 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeAdd": ModelAdd, "ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave, "CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple, "CLIPMergeSimple": CLIPMergeSimple,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
} }
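
Both new nodes follow the CheckpointSave pattern: collect prompt metadata, resolve a numbered path under the output directory via folder_paths.get_save_image_path, and write a .safetensors file. A hedged sketch of driving VAESave through the prompt API; the node ids, checkpoint name, and the VAE output index of CheckpointLoaderSimple are assumptions for illustration:

# Python dict form of the JSON graph that would be POSTed to /prompt
prompt = {
    "1": {"class_type": "CheckpointLoaderSimple",
          "inputs": {"ckpt_name": "sd_xl_base_1.0.safetensors"}},   # placeholder checkpoint
    "2": {"class_type": "VAESave",
          "inputs": {"vae": ["1", 2],                               # output index 2 assumed to be the VAE
                     "filename_prefix": "vae/ComfyUI_vae"}},
}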


@ -2,6 +2,7 @@ import os
import sys import sys
import copy import copy
import json import json
import logging
import threading import threading
import heapq import heapq
import traceback import traceback
@ -156,7 +157,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
except comfy.model_management.InterruptProcessingException as iex: except comfy.model_management.InterruptProcessingException as iex:
print("Processing interrupted") logging.info("Processing interrupted")
# skip formatting inputs/outputs # skip formatting inputs/outputs
error_details = { error_details = {
@ -177,8 +178,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
for node_id, node_outputs in outputs.items(): for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
print("!!! Exception during processing !!!") logging.error("!!! Exception during processing !!!")
print(traceback.format_exc()) logging.error(traceback.format_exc())
error_details = { error_details = {
"node_id": unique_id, "node_id": unique_id,
@ -636,11 +637,11 @@ def validate_prompt(prompt):
if valid is True: if valid is True:
good_outputs.add(o) good_outputs.add(o)
else: else:
print(f"Failed to validate prompt for output {o}:") logging.error(f"Failed to validate prompt for output {o}:")
if len(reasons) > 0: if len(reasons) > 0:
print("* (prompt):") logging.error("* (prompt):")
for reason in reasons: for reason in reasons:
print(f" - {reason['message']}: {reason['details']}") logging.error(f" - {reason['message']}: {reason['details']}")
errors += [(o, reasons)] errors += [(o, reasons)]
for node_id, result in validated.items(): for node_id, result in validated.items():
valid = result[0] valid = result[0]
@ -656,11 +657,11 @@ def validate_prompt(prompt):
"dependent_outputs": [], "dependent_outputs": [],
"class_type": class_type "class_type": class_type
} }
print(f"* {class_type} {node_id}:") logging.error(f"* {class_type} {node_id}:")
for reason in reasons: for reason in reasons:
print(f" - {reason['message']}: {reason['details']}") logging.error(f" - {reason['message']}: {reason['details']}")
node_errors[node_id]["dependent_outputs"].append(o) node_errors[node_id]["dependent_outputs"].append(o)
print("Output will be ignored") logging.error("Output will be ignored")
if len(good_outputs) == 0: if len(good_outputs) == 0:
errors_list = [] errors_list = []
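
Moving these messages from print() to logging means they respect whatever handlers and levels the host process configures; with no configuration Python only emits WARNING and above, so the INFO-level "Processing interrupted" line needs an explicit basicConfig to show up. A minimal sketch, not taken from this diff:

import logging

# surface INFO-level messages (e.g. "Processing interrupted") alongside the error output
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")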


@ -1,5 +1,6 @@
#Rename this to extra_model_paths.yaml and ComfyUI will load it #Rename this to extra_model_paths.yaml and ComfyUI will load it
#config for a1111 ui #config for a1111 ui
#all you have to do is change the base_path to where yours is installed #all you have to do is change the base_path to where yours is installed
a111: a111:
@ -19,6 +20,21 @@ a111:
hypernetworks: models/hypernetworks hypernetworks: models/hypernetworks
controlnet: models/ControlNet controlnet: models/ControlNet
#config for comfyui
#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc.
#comfyui:
# base_path: path/to/comfyui/
# checkpoints: models/checkpoints/
# clip: models/clip/
# clip_vision: models/clip_vision/
# configs: models/configs/
# controlnet: models/controlnet/
# embeddings: models/embeddings/
# loras: models/loras/
# upscale_models: models/upscale_models/
# vae: models/vae/
#other_ui: #other_ui:
# base_path: path/to/ui # base_path: path/to/ui
# checkpoints: models/checkpoints # checkpoints: models/checkpoints


@ -29,6 +29,8 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@ -46,6 +48,10 @@ def set_temp_directory(temp_dir):
global temp_directory global temp_directory
temp_directory = temp_dir temp_directory = temp_dir
def set_input_directory(input_dir):
global input_directory
input_directory = input_dir
def get_output_directory(): def get_output_directory():
global output_directory global output_directory
return output_directory return output_directory
@ -140,7 +146,7 @@ def recursive_search(directory, excluded_dir_names=None):
return result, dirs return result, dirs
def filter_files_extensions(files, extensions): def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
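
filter_files_extensions() now treats an empty extension set as "accept every file", so folders can be registered without an extension whitelist. A tiny illustration with toy file names:

from folder_paths import filter_files_extensions

files = ["model.pt", "README.md", "LICENSE"]
print(filter_files_extensions(files, {".pt"}))  # ['model.pt']
print(filter_files_extensions(files, set()))    # all three files: an empty set disables the filter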


@ -56,7 +56,12 @@ def get_previewer(device, latent_format):
# TODO previewer methods # TODO previewer methods
taesd_decoder_path = None taesd_decoder_path = None
if latent_format.taesd_decoder_name is not None: if latent_format.taesd_decoder_name is not None:
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) taesd_decoder_path = next(
(fn for fn in folder_paths.get_filename_list("vae_approx")
if fn.startswith(latent_format.taesd_decoder_name)),
""
)
taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
if method == LatentPreviewMethod.Auto: if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB method = LatentPreviewMethod.Latent2RGB

main.py

@ -175,6 +175,16 @@ if __name__ == "__main__":
print(f"Setting output directory to: {output_dir}") print(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir) folder_paths.set_output_directory(output_dir)
#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
print(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci: if args.quick_test_for_ci:
exit(0) exit(0)


@ -584,7 +584,8 @@ class VAELoader:
#TODO: scale factor? #TODO: scale factor?
def load_vae(self, vae_name): def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name) vae_path = folder_paths.get_full_path("vae", vae_name)
vae = comfy.sd.VAE(ckpt_path=vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,) return (vae,)
class ControlNetLoader: class ControlNetLoader:
@ -1202,7 +1203,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
noise_mask = latent["noise_mask"] noise_mask = latent["noise_mask"]
callback = latent_preview.prepare_callback(model, steps) callback = latent_preview.prepare_callback(model, steps)
disable_pbar = False disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
@ -1660,7 +1661,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"KSampler": "KSampler", "KSampler": "KSampler",
"KSamplerAdvanced": "KSampler (Advanced)", "KSamplerAdvanced": "KSampler (Advanced)",
# Loaders # Loaders
"CheckpointLoader": "Load Checkpoint (With Config)", "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
"CheckpointLoaderSimple": "Load Checkpoint", "CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE", "VAELoader": "Load VAE",
"LoraLoader": "Load LoRA", "LoraLoader": "Load LoRA",


@ -47,7 +47,7 @@
" !git pull\n", " !git pull\n",
"\n", "\n",
"!echo -= Install dependencies =-\n", "!echo -= Install dependencies =-\n",
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117" "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
] ]
}, },
{ {


@ -5,6 +5,61 @@ function setNodeMode(node, mode) {
node.graph.change(); node.graph.change();
} }
function addNodesToGroup(group, nodes=[]) {
var x1, y1, x2, y2;
var nx1, ny1, nx2, ny2;
var node;
x1 = y1 = x2 = y2 = -1;
nx1 = ny1 = nx2 = ny2 = -1;
for (var n of [group._nodes, nodes]) {
for (var i in n) {
node = n[i]
nx1 = node.pos[0]
ny1 = node.pos[1]
nx2 = node.pos[0] + node.size[0]
ny2 = node.pos[1] + node.size[1]
if (node.type != "Reroute") {
ny1 -= LiteGraph.NODE_TITLE_HEIGHT;
}
if (node.flags?.collapsed) {
ny2 = ny1 + LiteGraph.NODE_TITLE_HEIGHT;
if (node?._collapsed_width) {
nx2 = nx1 + Math.round(node._collapsed_width);
}
}
if (x1 == -1 || nx1 < x1) {
x1 = nx1;
}
if (y1 == -1 || ny1 < y1) {
y1 = ny1;
}
if (x2 == -1 || nx2 > x2) {
x2 = nx2;
}
if (y2 == -1 || ny2 > y2) {
y2 = ny2;
}
}
}
var padding = 10;
y1 = y1 - Math.round(group.font_size * 1.4);
group.pos = [x1 - padding, y1 - padding];
group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2];
}
app.registerExtension({ app.registerExtension({
name: "Comfy.GroupOptions", name: "Comfy.GroupOptions",
setup() { setup() {
@ -14,6 +69,17 @@ app.registerExtension({
const options = orig.apply(this, arguments); const options = orig.apply(this, arguments);
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]); const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
if (!group) { if (!group) {
options.push({
content: "Add Group For Selected Nodes",
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
var group = new LiteGraph.LGraphGroup();
addNodesToGroup(group, this.selected_nodes)
app.canvas.graph.add(group);
this.graph.change();
}
});
return options; return options;
} }
@ -21,6 +87,15 @@ app.registerExtension({
group.recomputeInsideNodes(); group.recomputeInsideNodes();
const nodesInGroup = group._nodes; const nodesInGroup = group._nodes;
options.push({
content: "Add Selected Nodes To Group",
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
addNodesToGroup(group, this.selected_nodes)
this.graph.change();
}
});
// No nodes in group, return default options // No nodes in group, return default options
if (nodesInGroup.length === 0) { if (nodesInGroup.length === 0) {
return options; return options;
@ -38,6 +113,23 @@ app.registerExtension({
} }
} }
options.push({
content: "Fit Group To Nodes",
callback: () => {
addNodesToGroup(group)
this.graph.change();
}
});
options.push({
content: "Select Nodes",
callback: () => {
this.selectNodes(nodesInGroup);
this.graph.change();
this.canvas.focus();
}
});
// Modes // Modes
// 0: Always // 0: Always
// 1: On Event // 1: On Event


@ -3796,7 +3796,7 @@
out = out || new Float32Array(4); out = out || new Float32Array(4);
out[0] = this.pos[0] - 4; out[0] = this.pos[0] - 4;
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT; out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
out[2] = this.size[0] + 4; out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT; out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
if (this.onBounding) { if (this.onBounding) {


@ -928,6 +928,16 @@ export class ComfyApp {
block_default = true; block_default = true;
} }
// Alt + C collapse/uncollapse
if (e.key === 'c' && e.altKey) {
if (this.selected_nodes) {
for (var i in this.selected_nodes) {
this.selected_nodes[i].collapse()
}
}
block_default = true;
}
// Ctrl+C Copy // Ctrl+C Copy
if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) { if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) {
// Trigger onCopy // Trigger onCopy
@ -1592,7 +1602,7 @@ export class ComfyApp {
all_inputs = all_inputs.concat(Object.keys(parent.inputs)) all_inputs = all_inputs.concat(Object.keys(parent.inputs))
for (let parent_input in all_inputs) { for (let parent_input in all_inputs) {
parent_input = all_inputs[parent_input]; parent_input = all_inputs[parent_input];
if (parent.inputs[parent_input].type === node.inputs[i].type) { if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
link = parent.getInputLink(parent_input); link = parent.getInputLink(parent_input);
if (link) { if (link) {
parent = parent.getInputNode(parent_input); parent = parent.getInputNode(parent_input);


@ -809,7 +809,8 @@ export class ComfyUI {
if ( if (
this.lastQueueSize != 0 && this.lastQueueSize != 0 &&
status.exec_info.queue_remaining == 0 && status.exec_info.queue_remaining == 0 &&
document.getElementById("autoQueueCheckbox").checked document.getElementById("autoQueueCheckbox").checked &&
! app.lastExecutionError
) { ) {
app.queuePrompt(0, this.batchCount); app.queuePrompt(0, this.batchCount);
} }