Mirror of https://github.com/comfyanonymous/ComfyUI.git

Merge branch 'comfyanonymous:master' into feat/widgetFeedback
Commit d7ed61d4cb
@@ -46,6 +46,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Alt + C | Collapse/uncollapse selected nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes |
@@ -89,6 +90,8 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints

Put your VAE in: models/vae

Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier.
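As a quick illustration of that version constraint, a hypothetical startup guard (not part of ComfyUI) could fail fast on an unsupported interpreter:

```python
import sys

# Hypothetical guard illustrating the note above; it is not part of the repository.
if sys.version_info >= (3, 12):
    raise RuntimeError("Python 3.12 is not supported yet; use Python 3.11 or earlier.")
print(f"Python {sys.version_info.major}.{sys.version_info.minor} is supported.")
```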
### AMD GPUs (Linux only)

AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
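Once that install has run, a quick sanity check that the ROCm build of PyTorch is the one in use (illustrative only, not from this commit):

```python
import torch

# A ROCm build of PyTorch reports a "+rocm" suffix in its version string, and AMD GPUs
# are exposed through the CUDA API surface via HIP, so torch.cuda.is_available() should
# return True once the install and the GPU driver are working.
print(torch.__version__)
print(torch.cuda.is_available())
```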
@@ -34,8 +34,7 @@ class ControlNet(nn.Module):
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
use_bf16=False,
dtype=torch.float32,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
@@ -108,8 +107,7 @@ class ControlNet(nn.Module):
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.dtype = dtype
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
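The recurring theme of this commit is that boolean use_fp16/use_bf16 flags give way to a single torch dtype passed into the model. A minimal sketch of the old selection logic that the new `dtype` argument subsumes (the helper name is hypothetical):

```python
import torch

def pick_unet_dtype(use_fp16: bool = False, use_bf16: bool = False) -> torch.dtype:
    # Old behaviour expressed as a function: two booleans mapped onto one dtype.
    if use_fp16:
        return torch.float16
    if use_bf16:
        return torch.bfloat16
    return torch.float32

print(pick_unet_dtype(use_bf16=True))  # torch.bfloat16, handed to ControlNet(..., dtype=...)
```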
@@ -39,6 +39,7 @@ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORI
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
@@ -52,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")

parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")

fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
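A standalone sketch of that flag layout: because --bf16-unet is added outside the mutually exclusive fp group, it can be combined with a forced fp16 or fp32 setting.

```python
import argparse

parser = argparse.ArgumentParser()
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--force-fp32", action="store_true")
fp_group.add_argument("--force-fp16", action="store_true")
# The new flag lives outside the group, so it composes with either forced precision.
parser.add_argument("--bf16-unet", action="store_true")

print(parser.parse_args(["--force-fp16", "--bf16-unet"]))
# Namespace(force_fp32=False, force_fp16=True, bf16_unet=True)
```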
@@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None):

controlnet_config = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
use_fp16 = comfy.model_management.should_use_fp16()
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None):
return net

if controlnet_config is None:
use_fp16 = comfy.model_management.should_use_fp16()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None):
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)

if use_fp16:
control_model = control_model.half()
control_model = control_model.to(unet_dtype)

global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
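The cast at the end of load_controlnet changes from a conditional .half() to an unconditional .to(unet_dtype). A toy sketch of the new behaviour (the Linear module and the bfloat16 value are placeholders, not ComfyUI code):

```python
import torch

control_model = torch.nn.Linear(4, 4)          # stand-in for the loaded ControlNet
unet_dtype = torch.bfloat16                    # stand-in for comfy.model_management.unet_dtype()
control_model = control_model.to(unet_dtype)   # replaces the old `if use_fp16: control_model.half()`
print(next(control_model.parameters()).dtype)  # torch.bfloat16
```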
@@ -31,6 +31,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire

vae = None
if output_vae:
vae = comfy.sd.VAE(ckpt_path=vae_path)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)

return (unet, clip, vae)
@@ -20,7 +20,7 @@ class SD15(LatentFormat):
[-0.2829, 0.1762, 0.2721],
[-0.2120, -0.2616, -0.7177]
]
self.taesd_decoder_name = "taesd_decoder.pth"
self.taesd_decoder_name = "taesd_decoder"

class SDXL(LatentFormat):
def __init__(self):
@@ -32,4 +32,4 @@ class SDXL(LatentFormat):
[ 0.0568, 0.1687, -0.0755],
[-0.3112, -0.2359, -0.2076]
]
self.taesd_decoder_name = "taesdxl_decoder.pth"
self.taesd_decoder_name = "taesdxl_decoder"
@@ -2,67 +2,66 @@ import torch
# import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution

from comfy.ldm.util import instantiate_from_config
from comfy.ldm.modules.ema import LitEma

# class AutoencoderKL(pl.LightningModule):
class AutoencoderKL(torch.nn.Module):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = True):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
self.sample = sample

def get_trainable_parameters(self) -> Any:
yield from ()

def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log

class AbstractAutoencoder(torch.nn.Module):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""

def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
**kwargs,
):
super().__init__()

self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None:
self.monitor = monitor

self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def get_input(self, batch) -> Any:
raise NotImplementedError()

def init_from_ckpt(self, path, ignore_keys=list()):
if path.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(path, device="cpu")
else:
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def on_train_batch_end(self, *args, **kwargs):
# for EMA computation
if self.use_ema:
self.model_ema(self)

@contextmanager
def ema_scope(self, context=None):
@@ -70,154 +69,159 @@ class AutoencoderKL(torch.nn.Module):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
logpy.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
logpy.info(f"{context}: Restored training weights")

def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")

def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("decode()-method of abstract base class called")

def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)

def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def configure_optimizers(self) -> Any:
raise NotImplementedError()

def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x

def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
class AutoencodingEngine(AbstractAutoencoder):
"""
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
(we also restore them explicitly as special cases for legacy reasons).
Regularizations such as KL or VQ are moved to the regularizer class.
"""

if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
def __init__(
self,
*args,
encoder_config: Dict,
decoder_config: Dict,
regularizer_config: Dict,
**kwargs,
):
super().__init__(*args, **kwargs)

if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")

self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss

def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict

def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)

discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)

self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict

def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.regularization: AbstractRegularizer = instantiate_from_config(
regularizer_config
)

def get_last_layer(self):
return self.decoder.conv_out.weight
return self.decoder.get_last_layer()

@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z

def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x

def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log

class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()

def encode(self, x, *args, **kwargs):
return x
class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
super().__init__(
encoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
"params": ddconfig,
},
decoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim

def decode(self, x, *args, **kwargs):
return x
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params

def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def encode(
self, x: torch.Tensor, return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)

def forward(self, x, *args, **kwargs):
return x
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z

def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)

return dec

class AutoencoderKL(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
)
},
**kwargs,
)
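A worked toy example of what the DiagonalGaussianRegularizer introduced above does with the encoder output: split the channels into mean/logvar, sample with the reparameterization trick, and record a per-sample KL term in the regularization log (the shapes here are made up for illustration):

```python
import torch

z = torch.randn(2, 8, 4, 4)                     # pretend encoder output with 2 * z_channels channels
mean, logvar = torch.chunk(z, 2, dim=1)
sample = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
kl = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1.0 - logvar) / z.shape[0]
print(sample.shape, float(kl))                  # torch.Size([2, 4, 4, 4]) plus a scalar KL value
```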
@@ -94,253 +94,222 @@ def zero_module(module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

def attention_basic(q, k, v, heads, mask=None):
h = heads
scale = (q.shape[-1] // heads) ** -0.5
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * scale

self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
del q, k

def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)

# compute attention
b,c,h,w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)

w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)

# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)

return x+h_
out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return out

class CrossAttentionBirchSan(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
def attention_sub_quad(query, key, value, heads, mask=None):
scale = (query.shape[-1] // heads) ** -0.5
query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1)
del key
value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)

self.scale = dim_head ** -0.5
self.heads = heads
dtype = query.dtype
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
else:
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key_t.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)

self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD

def forward(self, x, context=None, value=None, mask=None):
h = self.heads
kv_chunk_size_min = None

query = self.to_q(x)
context = default(context, x)
key = self.to_k(context)
if value is not None:
value = self.to_v(value)
else:
value = self.to_v(context)
#not sure at all about the math here
#TODO: tweak this
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 4
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 2
else:
query_chunk_size_x = 1024
kv_chunk_size_min_x = None
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
if kv_chunk_size_x < 1024:
kv_chunk_size_x = None

del context, x
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
kv_chunk_size = k_tokens
else:
query_chunk_size = query_chunk_size_x
kv_chunk_size = kv_chunk_size_x
kv_chunk_size_min = kv_chunk_size_min_x

query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
key_t = key.transpose(1,2).unflatten(1, (self.heads, -1)).flatten(end_dim=1)
del key
value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
hidden_states = efficient_dot_product_attention(
query,
key_t,
value,
query_chunk_size=query_chunk_size,
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=False,
upcast_attention=upcast_attention,
)

dtype = query.dtype
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
else:
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key_t.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
hidden_states = hidden_states.to(dtype)

mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states

chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
def attention_split(q, k, v, heads, mask=None):
scale = (q.shape[-1] // heads) ** -0.5
h = heads
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

kv_chunk_size_min = None
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

#not sure at all about the math here
#TODO: tweak this
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 4
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
query_chunk_size_x = 1024 * 2
else:
query_chunk_size_x = 1024
kv_chunk_size_min_x = None
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
if kv_chunk_size_x < 1024:
kv_chunk_size_x = None
mem_free_total = model_management.get_free_memory(q.device)

if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
query_chunk_size = q_tokens
kv_chunk_size = k_tokens
else:
query_chunk_size = query_chunk_size_x
kv_chunk_size = kv_chunk_size_x
kv_chunk_size_min = kv_chunk_size_min_x

hidden_states = efficient_dot_product_attention(
query,
key_t,
value,
query_chunk_size=query_chunk_size,
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=self.training,
upcast_attention=upcast_attention,
)

hidden_states = hidden_states.to(dtype)

hidden_states = hidden_states.unflatten(0, (-1, self.heads)).transpose(1,2).flatten(start_dim=2)

out_proj, dropout = self.to_out
hidden_states = out_proj(hidden_states)
hidden_states = dropout(hidden_states)

return hidden_states
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
class CrossAttentionDoggettx(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

self.scale = dim_head ** -0.5
self.heads = heads
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)

def forward(self, x, context=None, value=None, mask=None):
h = self.heads

q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
if value is not None:
v_in = self.to_v(value)
del value
else:
v_in = self.to_v(context)
del context, x

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in

r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

mem_free_total = model_management.get_free_memory(q.device)

gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1

if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False
cleared_cache = False
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * self.scale
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
first_op_done = True

s2 = s1.softmax(dim=-1).to(v.dtype)
del s1

r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:
cleared_cache = True
print("out of memory error, emptying cache and trying again")
continue
steps *= 2
if steps > 64:
raise e
print("out of memory error, increasing steps and trying again", steps)
# print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
first_op_done = False
cleared_cache = False
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
first_op_done = True

s2 = s1.softmax(dim=-1).to(v.dtype)
del s1

r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:
cleared_cache = True
print("out of memory error, emptying cache and trying again")
continue
steps *= 2
if steps > 64:
raise e
print("out of memory error, increasing steps and trying again", steps)
else:
raise e

del q, k, v
del q, k, v

r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return r2

return self.to_out(r2)
def attention_xformers(q, k, v, heads, mask=None):
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], heads, -1)
.permute(0, 2, 1, 3)
.reshape(b * heads, t.shape[1], -1)
.contiguous(),
(q, k, v),
)

# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, heads, out.shape[1], -1)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], -1)
)
return out

def attention_pytorch(q, k, v, heads, mask=None):
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out

optimized_attention = attention_basic
optimized_attention_masked = attention_basic

if model_management.xformers_enabled():
print("Using xformers cross attention")
optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention")
optimized_attention = attention_pytorch
else:
if args.use_split_cross_attention:
print("Using split optimization for cross attention")
optimized_attention = attention_split
else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad

if model_management.pytorch_attention_enabled():
optimized_attention_masked = attention_pytorch

class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
@@ -348,62 +317,6 @@ class CrossAttention(nn.Module):
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.heads = heads

self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
nn.Dropout(dropout)
)

def forward(self, x, context=None, value=None, mask=None):
h = self.heads

q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
if value is not None:
v = self.to_v(value)
del value
else:
v = self.to_v(context)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
with torch.autocast(enabled=False, device_type = 'cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

del q, k

if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)

# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)

out = einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)

class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.heads = heads
self.dim_head = dim_head

@@ -412,7 +325,6 @@ class MemoryEfficientCrossAttention(nn.Module):
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None

def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x)
@@ -424,85 +336,12 @@ class MemoryEfficientCrossAttention(nn.Module):
else:
v = self.to_v(context)

b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)

# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
return self.to_out(out)

class CrossAttentionPytorch(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.heads = heads
self.dim_head = dim_head

self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)

self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None

def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
if value is not None:
v = self.to_v(value)
del value
if mask is None:
out = optimized_attention(q, k, v, self.heads)
else:
v = self.to_v(context)

b, _, _ = q.shape
q, k, v = map(
lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
(q, k, v),
)

out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

if exists(mask):
raise NotImplementedError
out = (
out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
)

out = optimized_attention_masked(q, k, v, self.heads, mask)
return self.to_out(out)

if model_management.xformers_enabled():
print("Using xformers cross attention")
CrossAttention = MemoryEfficientCrossAttention
elif model_management.pytorch_attention_enabled():
print("Using pytorch cross attention")
CrossAttention = CrossAttentionPytorch
else:
if args.use_split_cross_attention:
print("Using split optimization for cross attention")
CrossAttention = CrossAttentionDoggettx
else:
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
CrossAttention = CrossAttentionBirchSan

class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
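A standalone sketch of the functional interface this refactor introduces: every backend above takes (q, k, v, heads) with q/k/v shaped [batch, tokens, heads * dim_head] and returns the same layout, which is what lets the module-level optimized_attention be swapped freely. The function below mirrors the attention_pytorch path and is illustrative, not the library code.

```python
import torch

def optimized_attention_sketch(q, k, v, heads):
    b, _, dim = q.shape
    dim_head = dim // heads
    # Split heads out of the last dimension, attend, then merge them back.
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(2, 77, 8 * 64)
print(optimized_attention_sketch(q, k, v, heads=8).shape)  # torch.Size([2, 77, 512])
```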
@@ -6,7 +6,6 @@ import numpy as np
from einops import rearrange
from typing import Optional, Any

from ..attention import MemoryEfficientCrossAttention
from comfy import model_management
import comfy.ops
@@ -194,6 +193,52 @@ def slice_attention(q, k, v):

return r1

def normal_attention(q, k, v):
# compute attention
b,c,h,w = q.shape

q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)

r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
return h_

def xformers_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
)

try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out

def pytorch_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)

try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out

class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
@@ -221,6 +266,16 @@ class AttnBlock(nn.Module):
stride=1,
padding=0)

if model_management.xformers_enabled_vae():
print("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
print("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
print("Using split attention in VAE")
self.optimized_attention = normal_attention

def forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -228,161 +283,15 @@ class AttnBlock(nn.Module):
k = self.k(h_)
v = self.v(h_)

# compute attention
b,c,h,w = q.shape
h_ = self.optimized_attention(q, k, v)

q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)

r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
h_ = self.proj_out(h_)

return x+h_

class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels

self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.attention_op: Optional[Any] = None

def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)

# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
)

try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)

out = self.proj_out(out)
return x+out

class MemoryEfficientAttnBlockPytorch(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels

self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.attention_op: Optional[Any] = None

def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)

# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)

try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)

out = self.proj_out(out)
return x+out

class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None):
b, c, h, w = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
out = super().forward(x, context=context, mask=mask)
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
return x + out

def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
attn_type = "vanilla-xformers"
if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
attn_type = "vanilla-pytorch"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif attn_type == "vanilla-pytorch":
return MemoryEfficientAttnBlockPytorch(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
raise NotImplementedError()
return AttnBlock(in_channels)

class Model(nn.Module):
@@ -632,7 +541,10 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs):
conv_out_op=comfy.ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
**ignorekwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
@@ -661,12 +573,12 @@ class Decoder(nn.Module):

# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
self.mid.attn_1 = attn_op(block_in)
self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
@@ -678,13 +590,13 @@ class Decoder(nn.Module):
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
block.append(resnet_op(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
attn.append(attn_op(block_in))
up = nn.Module()
up.block = block
up.attn = attn
@@ -695,13 +607,13 @@ class Decoder(nn.Module):

# end
self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in,
self.conv_out = conv_out_op(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)

def forward(self, z):
def forward(self, z, **kwargs):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
@@ -712,16 +624,16 @@ class Decoder(nn.Module):
h = self.conv_in(z)

# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)

# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0:
h = self.up[i_level].upsample(h)
@@ -731,7 +643,7 @@ class Decoder(nn.Module):

h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h = self.conv_out(h, **kwargs)
if self.tanh_out:
h = torch.tanh(h)
return h
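A sketch of the VAE-side attention dispatch added in the Decoder/AttnBlock hunks above: here q, k and v are [B, C, H, W] feature maps, and the single-head pytorch_attention path flattens the spatial dimensions before calling scaled_dot_product_attention, then restores the original layout. The function below is illustrative, not the library code.

```python
import torch

def pytorch_attention_sketch(q, k, v):
    B, C, H, W = q.shape
    # Treat each spatial position as a token: [B, 1, H*W, C].
    q, k, v = (t.view(B, 1, C, -1).transpose(2, 3).contiguous() for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    return out.transpose(2, 3).reshape(B, C, H, W)

x = torch.randn(1, 64, 32, 32)
print(pytorch_attention_sketch(x, x, x).shape)  # torch.Size([1, 64, 32, 32])
```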
@@ -296,8 +296,7 @@ class UNetModel(nn.Module):
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
use_bf16=False,
dtype=th.float32,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
@@ -370,8 +369,7 @@ class UNetModel(nn.Module):
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.dtype = dtype
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, use_fp16):
|
||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
unet_config = {
|
||||
@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
|
||||
else:
|
||||
unet_config["adm_in_channels"] = None
|
||||
|
||||
unet_config["use_fp16"] = use_fp16
|
||||
unet_config["dtype"] = dtype
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
|
||||
@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config):
|
||||
print("no match", unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
|
||||
model_config = model_config_from_unet_config(unet_config)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return comfy.supported_models_base.BASE(unet_config)
|
||||
else:
|
||||
return model_config
|
||||
|
||||
def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
match = {}
|
||||
attention_resolutions = []
|
||||
|
||||
@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
|
||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
|
||||
|
||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||
|
||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||
|
||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
|
||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
|
||||
|
||||
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
|
||||
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
|
||||
|
||||
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
|
||||
@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
return unet_config
|
||||
return None
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
|
||||
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
|
||||
if unet_config is not None:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
||||
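Taken together, the detection entry points above now take a torch dtype instead of a use_fp16 flag. A hedged sketch of the updated call pattern, mirroring how sd.py uses it later in this diff (the checkpoint path is a placeholder and the comfy package is assumed importable):

import comfy.utils
import comfy.model_detection as model_detection
import comfy.model_management as model_management

sd = comfy.utils.load_torch_file("models/checkpoints/example.safetensors")  # placeholder path
parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)  # float16 / bfloat16 / float32
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)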
@ -154,14 +154,18 @@ def is_nvidia():
return True
return False

ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILABLE = False

VAE_DTYPE = torch.float32

try:
if is_nvidia():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
if torch.cuda.is_bf16_supported():
VAE_DTYPE = torch.bfloat16
@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
XFORMERS_IS_AVAILABLE = False

if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
@ -354,6 +357,8 @@ def load_models_gpu(models, memory_required=0):
current_loaded_models.insert(0, current_loaded_models.pop(index))
models_already_loaded.append(loaded_model)
else:
if hasattr(x, "model"):
print(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model)

if len(models_to_load) == 0:
@ -363,7 +368,7 @@ def load_models_gpu(models, memory_required=0):
free_memory(extra_mem, d, models_already_loaded)
return

print("loading new")
print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")

total_memory_required = {}
for loaded_model in models_to_load:
@ -405,7 +410,6 @@ def load_model_gpu(model):
def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
print(sys.getrefcount(current_loaded_models[i].model))
if sys.getrefcount(current_loaded_models[i].model) <= 2:
to_delete = [i] + to_delete

@ -444,6 +448,13 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev

def unet_dtype(device=None, model_params=0):
if args.bf16_unet:
return torch.bfloat16
if should_use_fp16(device=device, model_params=model_params):
return torch.float16
return torch.float32

def text_encoder_offload_device():
if args.gpu_only:
return get_torch_device()
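The new unet_dtype() helper centralizes the UNET precision choice: --bf16-unet wins, then fp16 when the device supports it, else fp32. A standalone re-implementation sketch of that decision (bf16_unet and fp16_ok stand in for args.bf16_unet and should_use_fp16()):

import torch

def pick_unet_dtype(bf16_unet=False, fp16_ok=True):
    if bf16_unet:                # --bf16-unet forces bfloat16
        return torch.bfloat16
    if fp16_ok:                  # should_use_fp16() on the target device
        return torch.float16
    return torch.float32

print(pick_unet_dtype(bf16_unet=True))  # torch.bfloat16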
@ -656,7 +667,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
return False

#FP16 is just broken on these cards
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"]
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
for x in nvidia_16_series:
if x in props.name:
return False

@ -107,6 +107,10 @@ class ModelPatcher:
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
if "unet_wrapper_function" in self.model_options:
wrap_func = self.model_options["unet_wrapper_function"]
if hasattr(wrap_func, "to"):
self.model_options["unet_wrapper_function"] = wrap_func.to(device)

def model_dtype(self):
if hasattr(self.model, "get_dtype"):

comfy/sd.py
@ -4,7 +4,7 @@ import math

from comfy import model_management
from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
import yaml

import comfy.utils
@ -140,21 +140,24 @@ class CLIP:
return self.patcher.get_key_patches()

class VAE:
def __init__(self, ckpt_path=None, device=None, config=None):
def __init__(self, sd=None, device=None, config=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)

if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
if ckpt_path is not None:
sd = comfy.utils.load_torch_file(ckpt_path)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
print("Missing VAE keys", m)

m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
print("Missing VAE keys", m)

if len(u) > 0:
print("Leftover VAE keys", u)

if device is None:
device = model_management.vae_device()
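Because the VAE constructor now takes a state dict instead of a checkpoint path, callers load the weights themselves first. A sketch of the new loading pattern (the path is a placeholder), matching the VAELoader change later in this diff:

import comfy.utils
import comfy.sd

sd = comfy.utils.load_torch_file("models/vae/example_vae.safetensors")  # placeholder path
vae = comfy.sd.VAE(sd=sd)  # diffusers-format keys are converted inside __init__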
@ -183,7 +186,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)

encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float()
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@ -229,7 +232,7 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float()
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()

except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@ -327,7 +330,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
if "params" in model_config_params["unet_config"]:
unet_config = model_config_params["unet_config"]["params"]
if "use_fp16" in unet_config:
fp16 = unet_config["use_fp16"]
fp16 = unet_config.pop("use_fp16")
if fp16:
unet_config["dtype"] = torch.float16

noise_aug_config = None
if "noise_aug_config" in model_config_params:
@ -373,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model.load_model_weights(state_dict, "model.diffusion_model.")

if output_vae:
w = WeightsLoader()
vae = VAE(config=vae_config)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, state_dict)
vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
vae = VAE(sd=vae_sd, config=vae_config)

if output_clip:
w = WeightsLoader()
@ -405,12 +408,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
clip_target = None

parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
fp16 = model_management.should_use_fp16(model_params=parameters)
unet_dtype = model_management.unet_dtype(model_params=parameters)

class WeightsLoader(torch.nn.Module):
pass

model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))

@ -418,21 +421,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_clipvision:
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)

dtype = torch.float32
if fp16:
dtype = torch.float16

if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
offload_device = model_management.unet_offload_device()
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
model.load_model_weights(sd, "model.diffusion_model.")

if output_vae:
vae = VAE()
w = WeightsLoader()
w.first_stage_model = vae.first_stage_model
load_model_weights(w, sd)
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae = VAE(sd=vae_sd)

if output_clip:
w = WeightsLoader()
@ -458,15 +455,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
def load_unet(unet_path): #load unet in diffusers format
sd = comfy.utils.load_torch_file(unet_path)
parameters = comfy.utils.calculate_parameters(sd)
fp16 = model_management.should_use_fp16(model_params=parameters)
unet_dtype = model_management.unet_dtype(model_params=parameters)
if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", fp16)
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
new_sd = sd

else: #diffusers
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
if model_config is None:
print("ERROR UNSUPPORTED UNET", unet_path)
return None

@ -6,6 +6,8 @@ Tiny AutoEncoder for Stable Diffusion
import torch
import torch.nn as nn

import comfy.utils

def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

@ -50,9 +52,9 @@ class TAESD(nn.Module):
self.encoder = Encoder()
self.decoder = Decoder()
if encoder_path is not None:
self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))

@staticmethod
def scale_latents(x):

@ -47,12 +47,17 @@ def state_dict_key_replace(state_dict, keys_to_replace):
state_dict[keys_to_replace[x]] = state_dict.pop(x)
return state_dict

def state_dict_prefix_replace(state_dict, replace_prefix):
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
if filter_keys:
out = {}
else:
out = state_dict
for rp in replace_prefix:
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
for x in replace:
state_dict[x[1]] = state_dict.pop(x[0])
return state_dict
w = state_dict.pop(x[0])
out[x[1]] = w
return out

def transformers_convert(sd, prefix_from, prefix_to, number):
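A self-contained sketch of the new filter_keys behaviour: with filter_keys=True only the keys matching a prefix are returned, which is how the VAE weights are now carved out of a full checkpoint above:

def prefix_replace(state_dict, replace_prefix, filter_keys=False):
    # Mirrors comfy.utils.state_dict_prefix_replace after this change.
    out = {} if filter_keys else state_dict
    for rp in replace_prefix:
        keys = [k for k in list(state_dict.keys()) if k.startswith(rp)]
        for k in keys:
            out[replace_prefix[rp] + k[len(rp):]] = state_dict.pop(k)
    return out

sd = {"first_stage_model.encoder.w": 1, "model.diffusion_model.w": 2}
print(prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True))
# {'encoder.w': 1} -- the diffusion model key is left behind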
@ -408,6 +413,10 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
output[b:b+1] = out/out_div
return output

PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
PROGRESS_BAR_ENABLED = enabled

PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):

@ -3,6 +3,7 @@ import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
import latent_preview
import torch
import comfy.utils

class BasicScheduler:
@ -219,7 +220,7 @@ class SamplerCustom:
x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)

disable_pbar = False
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)

out = latent.copy()

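The samplers now consult a global switch instead of hard-coding disable_pbar = False, so scripts that embed ComfyUI can silence the per-step bar. A short sketch of the intended use, based on the functions added above:

import comfy.utils

comfy.utils.set_progress_bar_enabled(False)           # turn the bar off globally
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED   # -> True, samplers skip the bar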
@ -19,6 +19,7 @@ def load_hypernetwork_patch(path, strength):
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
"softsign": torch.nn.Softsign,
"mish": torch.nn.Mish,
}

if activation_func not in valid_activation:
@ -42,7 +43,8 @@ def load_hypernetwork_patch(path, strength):
linears = list(map(lambda a: a[:-len(".weight")], linears))
layers = []

for i in range(len(linears)):
i = 0
while i < len(linears):
lin_name = linears[i]
last_layer = (i == (len(linears) - 1))
penultimate_layer = (i == (len(linears) - 2))
@ -56,10 +58,17 @@ def load_hypernetwork_patch(path, strength):
if (not last_layer) or (activate_output):
layers.append(valid_activation[activation_func]())
if is_layer_norm:
layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
i += 1
ln_name = linears[i]
ln_weight = attn_weights['{}.weight'.format(ln_name)]
ln_bias = attn_weights['{}.bias'.format(ln_name)]
ln = torch.nn.LayerNorm(ln_weight.shape[0])
ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
layers.append(ln)
if use_dropout:
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
layers.append(torch.nn.Dropout(p=0.3))
i += 1

output.append(torch.nn.Sequential(*layers))
out[dim] = torch.nn.ModuleList(output)

@ -240,8 +240,8 @@ class MaskComposite:
right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
visible_width, visible_height = (right - left, bottom - top,)

source_portion = source[:visible_height, :visible_width]
destination_portion = destination[top:bottom, left:right]
source_portion = source[:, :visible_height, :visible_width]
destination_portion = destination[:, top:bottom, left:right]

if operation == "multiply":
output[:, top:bottom, left:right] = destination_portion * source_portion
@ -282,10 +282,10 @@ class FeatherMask:
def feather(self, mask, left, top, right, bottom):
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()

left = min(left, output.shape[1])
right = min(right, output.shape[1])
top = min(top, output.shape[0])
bottom = min(bottom, output.shape[0])
left = min(left, output.shape[-1])
right = min(right, output.shape[-1])
top = min(top, output.shape[-2])
bottom = min(bottom, output.shape[-2])

for x in range(left):
feather_rate = (x + 1.0) / left

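The mask fixes above all come down to keeping the leading batch dimension: masks are (batch, height, width), so crops and size clamps must index the last two axes. A minimal illustration:

import torch

mask = torch.ones(1, 8, 8)                 # (batch, height, width)
top, bottom, left, right = 2, 6, 2, 6
portion = mask[:, top:bottom, left:right]  # shape (1, 4, 4), batch preserved
print(portion.shape, mask.shape[-1], mask.shape[-2])  # width and height, not batch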
@ -1,6 +1,7 @@
import comfy.sd
import comfy.utils
import comfy.model_base
import comfy.model_management

import folder_paths
import json
@ -178,6 +179,95 @@ class CheckpointSave:
comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
return {}

class CLIPSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()

@classmethod
def INPUT_TYPES(s):
return {"required": { "clip": ("CLIP",),
"filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True

CATEGORY = "advanced/model_merging"

def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)

metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])

comfy.model_management.load_models_gpu([clip.load_model()])
clip_sd = clip.get_sd()

for prefix in ["clip_l.", "clip_g.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
current_clip_sd = {}
for x in k:
current_clip_sd[x] = clip_sd.pop(x)
if len(current_clip_sd) == 0:
continue

p = prefix[:-1]
replace_prefix = {}
filename_prefix_ = filename_prefix
if len(p) > 0:
filename_prefix_ = "{}_{}".format(filename_prefix_, p)
replace_prefix[prefix] = ""
replace_prefix["transformer."] = ""

full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)

output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)

comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
return {}

class VAESave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()

@classmethod
def INPUT_TYPES(s):
return {"required": { "vae": ("VAE",),
"filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True

CATEGORY = "advanced/model_merging"

def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
prompt_info = ""
if prompt is not None:
prompt_info = json.dumps(prompt)

metadata = {}
if not args.disable_metadata:
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])

output_checkpoint = f"{filename}_{counter:05}_.safetensors"
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)

comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}

NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
@ -186,4 +276,6 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeAdd": ModelAdd,
"CheckpointSave": CheckpointSave,
"CLIPMergeSimple": CLIPMergeSimple,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
}

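The new CLIPSave node splits a combined text-encoder state dict into one file per prefix (clip_l., clip_g., then whatever remains). A standalone sketch of that grouping loop with toy keys:

clip_sd = {"clip_l.text_model.w": 1, "clip_g.text_model.w": 2, "logit_scale": 3}
for prefix in ["clip_l.", "clip_g.", ""]:
    current = {k: clip_sd.pop(k) for k in list(clip_sd) if k.startswith(prefix)}
    if current:
        print(prefix or "(remaining)", "->", list(current))
# clip_l. -> ['clip_l.text_model.w']
# clip_g. -> ['clip_g.text_model.w']
# (remaining) -> ['logit_scale']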
execution.py
@ -2,6 +2,7 @@ import os
import sys
import copy
import json
import logging
import threading
import heapq
import traceback
@ -156,7 +157,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
except comfy.model_management.InterruptProcessingException as iex:
print("Processing interrupted")
logging.info("Processing interrupted")

# skip formatting inputs/outputs
error_details = {
@ -177,8 +178,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
for node_id, node_outputs in outputs.items():
output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]

print("!!! Exception during processing !!!")
print(traceback.format_exc())
logging.error("!!! Exception during processing !!!")
logging.error(traceback.format_exc())

error_details = {
"node_id": unique_id,
@ -636,11 +637,11 @@ def validate_prompt(prompt):
if valid is True:
good_outputs.add(o)
else:
print(f"Failed to validate prompt for output {o}:")
logging.error(f"Failed to validate prompt for output {o}:")
if len(reasons) > 0:
print("* (prompt):")
logging.error("* (prompt):")
for reason in reasons:
print(f" - {reason['message']}: {reason['details']}")
logging.error(f" - {reason['message']}: {reason['details']}")
errors += [(o, reasons)]
for node_id, result in validated.items():
valid = result[0]
@ -656,11 +657,11 @@ def validate_prompt(prompt):
"dependent_outputs": [],
"class_type": class_type
}
print(f"* {class_type} {node_id}:")
logging.error(f"* {class_type} {node_id}:")
for reason in reasons:
print(f" - {reason['message']}: {reason['details']}")
logging.error(f" - {reason['message']}: {reason['details']}")
node_errors[node_id]["dependent_outputs"].append(o)
print("Output will be ignored")
logging.error("Output will be ignored")

if len(good_outputs) == 0:
errors_list = []

@ -1,5 +1,6 @@
#Rename this to extra_model_paths.yaml and ComfyUI will load it

#config for a1111 ui
#all you have to do is change the base_path to where yours is installed
a111:
@ -19,6 +20,21 @@ a111:
hypernetworks: models/hypernetworks
controlnet: models/ControlNet

#config for comfyui
#your base path should be either an existing comfy install or a central folder where you store all of your models, loras, etc.

#comfyui:
# base_path: path/to/comfyui/
# checkpoints: models/checkpoints/
# clip: models/clip/
# clip_vision: models/clip_vision/
# configs: models/configs/
# controlnet: models/controlnet/
# embeddings: models/embeddings/
# loras: models/loras/
# upscale_models: models/upscale_models/
# vae: models/vae/

#other_ui:
# base_path: path/to/ui
# checkpoints: models/checkpoints

@ -29,6 +29,8 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes

folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)

folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})

output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
@ -46,6 +48,10 @@ def set_temp_directory(temp_dir):
global temp_directory
temp_directory = temp_dir

def set_input_directory(input_dir):
global input_directory
input_directory = input_dir

def get_output_directory():
global output_directory
return output_directory
@ -140,7 +146,7 @@ def recursive_search(directory, excluded_dir_names=None):
return result, dirs

def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))

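A quick check of the adjusted filter: an empty extension set now matches every file, where previously it matched none. A standalone sketch:

import os

def filter_files_extensions(files, extensions):
    # Same condition as the changed line above.
    return sorted(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions
                         or len(extensions) == 0, files))

print(filter_files_extensions(["a.ckpt", "notes.txt"], set()))      # ['a.ckpt', 'notes.txt']
print(filter_files_extensions(["a.ckpt", "notes.txt"], {".ckpt"}))  # ['a.ckpt']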
@ -56,7 +56,12 @@ def get_previewer(device, latent_format):
# TODO previewer methods
taesd_decoder_path = None
if latent_format.taesd_decoder_name is not None:
taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name)
taesd_decoder_path = next(
(fn for fn in folder_paths.get_filename_list("vae_approx")
if fn.startswith(latent_format.taesd_decoder_name)),
""
)
taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)

if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB

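The previewer now does a prefix match against the vae_approx folder rather than requiring an exact filename. A toy version of the lookup (the filenames here are invented examples):

filenames = ["taesd_decoder.pth", "taesdxl_decoder.safetensors"]
prefix = "taesd_decoder"   # stands in for latent_format.taesd_decoder_name
match = next((fn for fn in filenames if fn.startswith(prefix)), "")
print(match)  # taesd_decoder.pth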
main.py
@ -175,6 +175,16 @@ if __name__ == "__main__":
print(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)

#These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))

if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
print(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)

if args.quick_test_for_ci:
exit(0)

nodes.py
@ -584,7 +584,8 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name)
vae = comfy.sd.VAE(ckpt_path=vae_path)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,)

class ControlNetLoader:
@ -1202,7 +1203,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
noise_mask = latent["noise_mask"]

callback = latent_preview.prepare_callback(model, steps)
disable_pbar = False
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
@ -1660,7 +1661,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"KSampler": "KSampler",
"KSamplerAdvanced": "KSampler (Advanced)",
# Loaders
"CheckpointLoader": "Load Checkpoint (With Config)",
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
"CheckpointLoaderSimple": "Load Checkpoint",
"VAELoader": "Load VAE",
"LoraLoader": "Load LoRA",

@ -47,7 +47,7 @@
" !git pull\n",
"\n",
"!echo -= Install dependencies =-\n",
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
]
},
{

@ -5,6 +5,61 @@ function setNodeMode(node, mode) {
node.graph.change();
}

function addNodesToGroup(group, nodes=[]) {
var x1, y1, x2, y2;
var nx1, ny1, nx2, ny2;
var node;

x1 = y1 = x2 = y2 = -1;
nx1 = ny1 = nx2 = ny2 = -1;

for (var n of [group._nodes, nodes]) {
for (var i in n) {
node = n[i]

nx1 = node.pos[0]
ny1 = node.pos[1]
nx2 = node.pos[0] + node.size[0]
ny2 = node.pos[1] + node.size[1]

if (node.type != "Reroute") {
ny1 -= LiteGraph.NODE_TITLE_HEIGHT;
}

if (node.flags?.collapsed) {
ny2 = ny1 + LiteGraph.NODE_TITLE_HEIGHT;

if (node?._collapsed_width) {
nx2 = nx1 + Math.round(node._collapsed_width);
}
}

if (x1 == -1 || nx1 < x1) {
x1 = nx1;
}

if (y1 == -1 || ny1 < y1) {
y1 = ny1;
}

if (x2 == -1 || nx2 > x2) {
x2 = nx2;
}

if (y2 == -1 || ny2 > y2) {
y2 = ny2;
}
}
}

var padding = 10;

y1 = y1 - Math.round(group.font_size * 1.4);

group.pos = [x1 - padding, y1 - padding];
group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2];
}

app.registerExtension({
name: "Comfy.GroupOptions",
setup() {
@ -14,6 +69,17 @@ app.registerExtension({
const options = orig.apply(this, arguments);
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
if (!group) {
options.push({
content: "Add Group For Selected Nodes",
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
var group = new LiteGraph.LGraphGroup();
addNodesToGroup(group, this.selected_nodes)
app.canvas.graph.add(group);
this.graph.change();
}
});

return options;
}

@ -21,6 +87,15 @@ app.registerExtension({
group.recomputeInsideNodes();
const nodesInGroup = group._nodes;

options.push({
content: "Add Selected Nodes To Group",
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
callback: () => {
addNodesToGroup(group, this.selected_nodes)
this.graph.change();
}
});

// No nodes in group, return default options
if (nodesInGroup.length === 0) {
return options;
@ -38,6 +113,23 @@ app.registerExtension({
}
}

options.push({
content: "Fit Group To Nodes",
callback: () => {
addNodesToGroup(group)
this.graph.change();
}
});

options.push({
content: "Select Nodes",
callback: () => {
this.selectNodes(nodesInGroup);
this.graph.change();
this.canvas.focus();
}
});

// Modes
// 0: Always
// 1: On Event

@ -3796,7 +3796,7 @@
out = out || new Float32Array(4);
out[0] = this.pos[0] - 4;
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
out[2] = this.size[0] + 4;
out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;

if (this.onBounding) {

@ -928,6 +928,16 @@ export class ComfyApp {
block_default = true;
}

// Alt + C collapse/uncollapse
if (e.key === 'c' && e.altKey) {
if (this.selected_nodes) {
for (var i in this.selected_nodes) {
this.selected_nodes[i].collapse()
}
}
block_default = true;
}

// Ctrl+C Copy
if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) {
// Trigger onCopy
@ -1592,7 +1602,7 @@ export class ComfyApp {
all_inputs = all_inputs.concat(Object.keys(parent.inputs))
for (let parent_input in all_inputs) {
parent_input = all_inputs[parent_input];
if (parent.inputs[parent_input].type === node.inputs[i].type) {
if (parent.inputs[parent_input]?.type === node.inputs[i].type) {
link = parent.getInputLink(parent_input);
if (link) {
parent = parent.getInputNode(parent_input);

@ -809,7 +809,8 @@ export class ComfyUI {
if (
this.lastQueueSize != 0 &&
status.exec_info.queue_remaining == 0 &&
document.getElementById("autoQueueCheckbox").checked
document.getElementById("autoQueueCheckbox").checked &&
! app.lastExecutionError
) {
app.queuePrompt(0, this.batchCount);
}