mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-08 10:47:32 +08:00
Merge remote-tracking branch 'origin/master' into group-nodes
This commit is contained in:
commit
288a1c6242
11
.github/workflows/test-ui.yaml
vendored
11
.github/workflows/test-ui.yaml
vendored
@ -10,8 +10,17 @@ jobs:
|
|||||||
- uses: actions/setup-node@v3
|
- uses: actions/setup-node@v3
|
||||||
with:
|
with:
|
||||||
node-version: 18
|
node-version: 18
|
||||||
|
- uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
- name: Install requirements
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
pip install -r requirements.txt
|
||||||
- name: Run Tests
|
- name: Run Tests
|
||||||
run: |
|
run: |
|
||||||
npm install
|
npm ci
|
||||||
|
npm run test:generate
|
||||||
npm test
|
npm test
|
||||||
working-directory: ./tests-ui
|
working-directory: ./tests-ui
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -14,3 +14,4 @@ venv/
|
|||||||
/web/extensions/*
|
/web/extensions/*
|
||||||
!/web/extensions/logging.js.example
|
!/web/extensions/logging.js.example
|
||||||
!/web/extensions/core/
|
!/web/extensions/core/
|
||||||
|
/tests-ui/data/object_info.json
|
||||||
@ -92,8 +92,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
|||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
|
||||||
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
|
||||||
else:
|
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
|
||||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
clip = ClipVisionModel(json_config)
|
clip = ClipVisionModel(json_config)
|
||||||
m, u = clip.load_sd(sd)
|
m, u = clip.load_sd(sd)
|
||||||
if len(m) > 0:
|
if len(m) > 0:
|
||||||
|
|||||||
@ -31,6 +31,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
|
|||||||
|
|
||||||
vae = None
|
vae = None
|
||||||
if output_vae:
|
if output_vae:
|
||||||
vae = comfy.sd.VAE(ckpt_path=vae_path)
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
|
vae = comfy.sd.VAE(sd=sd)
|
||||||
|
|
||||||
return (unet, clip, vae)
|
return (unet, clip, vae)
|
||||||
|
|||||||
@ -713,8 +713,8 @@ class UniPC:
|
|||||||
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
|
||||||
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
|
atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
|
||||||
):
|
):
|
||||||
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
|
||||||
t_T = self.noise_schedule.T if t_start is None else t_start
|
# t_T = self.noise_schedule.T if t_start is None else t_start
|
||||||
device = x.device
|
device = x.device
|
||||||
steps = len(timesteps) - 1
|
steps = len(timesteps) - 1
|
||||||
if method == 'multistep':
|
if method == 'multistep':
|
||||||
@ -769,8 +769,8 @@ class UniPC:
|
|||||||
callback(step_index, model_prev_list[-1], x, steps)
|
callback(step_index, model_prev_list[-1], x, steps)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if denoise_to_zero:
|
# if denoise_to_zero:
|
||||||
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
# x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -833,21 +833,33 @@ def expand_dims(v, dims):
|
|||||||
return v[(...,) + (None,)*(dims - 1)]
|
return v[(...,) + (None,)*(dims - 1)]
|
||||||
|
|
||||||
|
|
||||||
|
class SigmaConvert:
|
||||||
|
schedule = ""
|
||||||
|
def marginal_log_mean_coeff(self, sigma):
|
||||||
|
return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
|
||||||
|
|
||||||
|
def marginal_alpha(self, t):
|
||||||
|
return torch.exp(self.marginal_log_mean_coeff(t))
|
||||||
|
|
||||||
|
def marginal_std(self, t):
|
||||||
|
return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
|
||||||
|
|
||||||
|
def marginal_lambda(self, t):
|
||||||
|
"""
|
||||||
|
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
|
||||||
|
"""
|
||||||
|
log_mean_coeff = self.marginal_log_mean_coeff(t)
|
||||||
|
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
||||||
|
return log_mean_coeff - log_std
|
||||||
|
|
||||||
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||||
to_zero = False
|
timesteps = sigmas.clone()
|
||||||
if sigmas[-1] == 0:
|
if sigmas[-1] == 0:
|
||||||
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
|
timesteps = sigmas[:]
|
||||||
to_zero = True
|
timesteps[-1] = 0.001
|
||||||
else:
|
else:
|
||||||
timesteps = sigmas.clone()
|
timesteps = sigmas.clone()
|
||||||
|
ns = SigmaConvert()
|
||||||
alphas_cumprod = model.inner_model.alphas_cumprod
|
|
||||||
|
|
||||||
for s in range(timesteps.shape[0]):
|
|
||||||
timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
|
|
||||||
|
|
||||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
|
|
||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
img = image * ns.marginal_alpha(timesteps[0])
|
img = image * ns.marginal_alpha(timesteps[0])
|
||||||
@ -859,16 +871,10 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
|||||||
else:
|
else:
|
||||||
img = noise
|
img = noise
|
||||||
|
|
||||||
if to_zero:
|
|
||||||
timesteps[-1] = (1 / len(alphas_cumprod))
|
|
||||||
|
|
||||||
device = noise.device
|
|
||||||
|
|
||||||
|
|
||||||
model_type = "noise"
|
model_type = "noise"
|
||||||
|
|
||||||
model_fn = model_wrapper(
|
model_fn = model_wrapper(
|
||||||
model.predict_eps_discrete_timestep,
|
model.predict_eps_sigma,
|
||||||
ns,
|
ns,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
guidance_type="uncond",
|
guidance_type="uncond",
|
||||||
@ -878,6 +884,5 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
|||||||
order = min(3, len(timesteps) - 1)
|
order = min(3, len(timesteps) - 1)
|
||||||
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
||||||
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
||||||
if not to_zero:
|
x /= ns.marginal_alpha(timesteps[-1])
|
||||||
x /= ns.marginal_alpha(timesteps[-1])
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -97,6 +97,10 @@ class DiscreteSchedule(nn.Module):
|
|||||||
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
||||||
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
||||||
|
|
||||||
|
def predict_eps_sigma(self, input, sigma, **kwargs):
|
||||||
|
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
||||||
|
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
||||||
|
|
||||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
||||||
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
||||||
noise)."""
|
noise)."""
|
||||||
|
|||||||
@ -2,67 +2,66 @@ import torch
|
|||||||
# import pytorch_lightning as pl
|
# import pytorch_lightning as pl
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
|
|
||||||
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||||
|
|
||||||
from comfy.ldm.util import instantiate_from_config
|
from comfy.ldm.util import instantiate_from_config
|
||||||
from comfy.ldm.modules.ema import LitEma
|
from comfy.ldm.modules.ema import LitEma
|
||||||
|
|
||||||
# class AutoencoderKL(pl.LightningModule):
|
class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||||
class AutoencoderKL(torch.nn.Module):
|
def __init__(self, sample: bool = True):
|
||||||
def __init__(self,
|
|
||||||
ddconfig,
|
|
||||||
lossconfig,
|
|
||||||
embed_dim,
|
|
||||||
ckpt_path=None,
|
|
||||||
ignore_keys=[],
|
|
||||||
image_key="image",
|
|
||||||
colorize_nlabels=None,
|
|
||||||
monitor=None,
|
|
||||||
ema_decay=None,
|
|
||||||
learn_logvar=False
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.learn_logvar = learn_logvar
|
self.sample = sample
|
||||||
self.image_key = image_key
|
|
||||||
self.encoder = Encoder(**ddconfig)
|
def get_trainable_parameters(self) -> Any:
|
||||||
self.decoder = Decoder(**ddconfig)
|
yield from ()
|
||||||
self.loss = instantiate_from_config(lossconfig)
|
|
||||||
assert ddconfig["double_z"]
|
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
||||||
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
log = dict()
|
||||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
posterior = DiagonalGaussianDistribution(z)
|
||||||
self.embed_dim = embed_dim
|
if self.sample:
|
||||||
if colorize_nlabels is not None:
|
z = posterior.sample()
|
||||||
assert type(colorize_nlabels)==int
|
else:
|
||||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
z = posterior.mode()
|
||||||
|
kl_loss = posterior.kl()
|
||||||
|
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
||||||
|
log["kl_loss"] = kl_loss
|
||||||
|
return z, log
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractAutoencoder(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
|
||||||
|
unCLIP models, etc. Hence, it is fairly general, and specific features
|
||||||
|
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ema_decay: Union[None, float] = None,
|
||||||
|
monitor: Union[None, str] = None,
|
||||||
|
input_key: str = "jpg",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_key = input_key
|
||||||
|
self.use_ema = ema_decay is not None
|
||||||
if monitor is not None:
|
if monitor is not None:
|
||||||
self.monitor = monitor
|
self.monitor = monitor
|
||||||
|
|
||||||
self.use_ema = ema_decay is not None
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.ema_decay = ema_decay
|
|
||||||
assert 0. < ema_decay < 1.
|
|
||||||
self.model_ema = LitEma(self, decay=ema_decay)
|
self.model_ema = LitEma(self, decay=ema_decay)
|
||||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
if ckpt_path is not None:
|
def get_input(self, batch) -> Any:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
raise NotImplementedError()
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
def on_train_batch_end(self, *args, **kwargs):
|
||||||
if path.lower().endswith(".safetensors"):
|
# for EMA computation
|
||||||
import safetensors.torch
|
if self.use_ema:
|
||||||
sd = safetensors.torch.load_file(path, device="cpu")
|
self.model_ema(self)
|
||||||
else:
|
|
||||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
|
||||||
keys = list(sd.keys())
|
|
||||||
for k in keys:
|
|
||||||
for ik in ignore_keys:
|
|
||||||
if k.startswith(ik):
|
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
|
||||||
del sd[k]
|
|
||||||
self.load_state_dict(sd, strict=False)
|
|
||||||
print(f"Restored from {path}")
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def ema_scope(self, context=None):
|
def ema_scope(self, context=None):
|
||||||
@ -70,154 +69,159 @@ class AutoencoderKL(torch.nn.Module):
|
|||||||
self.model_ema.store(self.parameters())
|
self.model_ema.store(self.parameters())
|
||||||
self.model_ema.copy_to(self)
|
self.model_ema.copy_to(self)
|
||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Switched to EMA weights")
|
logpy.info(f"{context}: Switched to EMA weights")
|
||||||
try:
|
try:
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.model_ema.restore(self.parameters())
|
self.model_ema.restore(self.parameters())
|
||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
logpy.info(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def on_train_batch_end(self, *args, **kwargs):
|
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||||
if self.use_ema:
|
raise NotImplementedError("encode()-method of abstract base class called")
|
||||||
self.model_ema(self)
|
|
||||||
|
|
||||||
def encode(self, x):
|
def decode(self, *args, **kwargs) -> torch.Tensor:
|
||||||
h = self.encoder(x)
|
raise NotImplementedError("decode()-method of abstract base class called")
|
||||||
moments = self.quant_conv(h)
|
|
||||||
posterior = DiagonalGaussianDistribution(moments)
|
|
||||||
return posterior
|
|
||||||
|
|
||||||
def decode(self, z):
|
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||||
z = self.post_quant_conv(z)
|
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||||
dec = self.decoder(z)
|
return get_obj_from_str(cfg["target"])(
|
||||||
return dec
|
params, lr=lr, **cfg.get("params", dict())
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, input, sample_posterior=True):
|
def configure_optimizers(self) -> Any:
|
||||||
posterior = self.encode(input)
|
raise NotImplementedError()
|
||||||
if sample_posterior:
|
|
||||||
z = posterior.sample()
|
|
||||||
else:
|
|
||||||
z = posterior.mode()
|
|
||||||
dec = self.decode(z)
|
|
||||||
return dec, posterior
|
|
||||||
|
|
||||||
def get_input(self, batch, k):
|
|
||||||
x = batch[k]
|
|
||||||
if len(x.shape) == 3:
|
|
||||||
x = x[..., None]
|
|
||||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
|
||||||
return x
|
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
class AutoencodingEngine(AbstractAutoencoder):
|
||||||
inputs = self.get_input(batch, self.image_key)
|
"""
|
||||||
reconstructions, posterior = self(inputs)
|
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
|
||||||
|
(we also restore them explicitly as special cases for legacy reasons).
|
||||||
|
Regularizations such as KL or VQ are moved to the regularizer class.
|
||||||
|
"""
|
||||||
|
|
||||||
if optimizer_idx == 0:
|
def __init__(
|
||||||
# train encoder+decoder+logvar
|
self,
|
||||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
*args,
|
||||||
last_layer=self.get_last_layer(), split="train")
|
encoder_config: Dict,
|
||||||
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
decoder_config: Dict,
|
||||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
regularizer_config: Dict,
|
||||||
return aeloss
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
if optimizer_idx == 1:
|
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||||
# train the discriminator
|
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
self.regularization: AbstractRegularizer = instantiate_from_config(
|
||||||
last_layer=self.get_last_layer(), split="train")
|
regularizer_config
|
||||||
|
)
|
||||||
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
|
||||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
|
||||||
return discloss
|
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
|
||||||
log_dict = self._validation_step(batch, batch_idx)
|
|
||||||
with self.ema_scope():
|
|
||||||
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
|
||||||
return log_dict
|
|
||||||
|
|
||||||
def _validation_step(self, batch, batch_idx, postfix=""):
|
|
||||||
inputs = self.get_input(batch, self.image_key)
|
|
||||||
reconstructions, posterior = self(inputs)
|
|
||||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
|
||||||
last_layer=self.get_last_layer(), split="val"+postfix)
|
|
||||||
|
|
||||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
|
||||||
last_layer=self.get_last_layer(), split="val"+postfix)
|
|
||||||
|
|
||||||
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
|
|
||||||
self.log_dict(log_dict_ae)
|
|
||||||
self.log_dict(log_dict_disc)
|
|
||||||
return self.log_dict
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
lr = self.learning_rate
|
|
||||||
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
|
|
||||||
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
|
|
||||||
if self.learn_logvar:
|
|
||||||
print(f"{self.__class__.__name__}: Learning logvar")
|
|
||||||
ae_params_list.append(self.loss.logvar)
|
|
||||||
opt_ae = torch.optim.Adam(ae_params_list,
|
|
||||||
lr=lr, betas=(0.5, 0.9))
|
|
||||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
|
||||||
lr=lr, betas=(0.5, 0.9))
|
|
||||||
return [opt_ae, opt_disc], []
|
|
||||||
|
|
||||||
def get_last_layer(self):
|
def get_last_layer(self):
|
||||||
return self.decoder.conv_out.weight
|
return self.decoder.get_last_layer()
|
||||||
|
|
||||||
@torch.no_grad()
|
def encode(
|
||||||
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
self,
|
||||||
log = dict()
|
x: torch.Tensor,
|
||||||
x = self.get_input(batch, self.image_key)
|
return_reg_log: bool = False,
|
||||||
x = x.to(self.device)
|
unregularized: bool = False,
|
||||||
if not only_inputs:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||||
xrec, posterior = self(x)
|
z = self.encoder(x)
|
||||||
if x.shape[1] > 3:
|
if unregularized:
|
||||||
# colorize with random projection
|
return z, dict()
|
||||||
assert xrec.shape[1] > 3
|
z, reg_log = self.regularization(z)
|
||||||
x = self.to_rgb(x)
|
if return_reg_log:
|
||||||
xrec = self.to_rgb(xrec)
|
return z, reg_log
|
||||||
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
return z
|
||||||
log["reconstructions"] = xrec
|
|
||||||
if log_ema or self.use_ema:
|
|
||||||
with self.ema_scope():
|
|
||||||
xrec_ema, posterior_ema = self(x)
|
|
||||||
if x.shape[1] > 3:
|
|
||||||
# colorize with random projection
|
|
||||||
assert xrec_ema.shape[1] > 3
|
|
||||||
xrec_ema = self.to_rgb(xrec_ema)
|
|
||||||
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
|
|
||||||
log["reconstructions_ema"] = xrec_ema
|
|
||||||
log["inputs"] = x
|
|
||||||
return log
|
|
||||||
|
|
||||||
def to_rgb(self, x):
|
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||||
assert self.image_key == "segmentation"
|
x = self.decoder(z, **kwargs)
|
||||||
if not hasattr(self, "colorize"):
|
|
||||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
|
||||||
x = F.conv2d(x, weight=self.colorize)
|
|
||||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, x: torch.Tensor, **additional_decode_kwargs
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||||
|
z, reg_log = self.encode(x, return_reg_log=True)
|
||||||
|
dec = self.decode(z, **additional_decode_kwargs)
|
||||||
|
return z, dec, reg_log
|
||||||
|
|
||||||
class IdentityFirstStage(torch.nn.Module):
|
|
||||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
|
||||||
self.vq_interface = vq_interface
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def encode(self, x, *args, **kwargs):
|
class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||||
return x
|
def __init__(self, embed_dim: int, **kwargs):
|
||||||
|
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
||||||
|
ddconfig = kwargs.pop("ddconfig")
|
||||||
|
super().__init__(
|
||||||
|
encoder_config={
|
||||||
|
"target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
|
||||||
|
"params": ddconfig,
|
||||||
|
},
|
||||||
|
decoder_config={
|
||||||
|
"target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
|
||||||
|
"params": ddconfig,
|
||||||
|
},
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
self.quant_conv = torch.nn.Conv2d(
|
||||||
|
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
||||||
|
(1 + ddconfig["double_z"]) * embed_dim,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
def decode(self, x, *args, **kwargs):
|
def get_autoencoder_params(self) -> list:
|
||||||
return x
|
params = super().get_autoencoder_params()
|
||||||
|
return params
|
||||||
|
|
||||||
def quantize(self, x, *args, **kwargs):
|
def encode(
|
||||||
if self.vq_interface:
|
self, x: torch.Tensor, return_reg_log: bool = False
|
||||||
return x, None, [None, None, None]
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||||
return x
|
if self.max_batch_size is None:
|
||||||
|
z = self.encoder(x)
|
||||||
|
z = self.quant_conv(z)
|
||||||
|
else:
|
||||||
|
N = x.shape[0]
|
||||||
|
bs = self.max_batch_size
|
||||||
|
n_batches = int(math.ceil(N / bs))
|
||||||
|
z = list()
|
||||||
|
for i_batch in range(n_batches):
|
||||||
|
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
||||||
|
z_batch = self.quant_conv(z_batch)
|
||||||
|
z.append(z_batch)
|
||||||
|
z = torch.cat(z, 0)
|
||||||
|
|
||||||
def forward(self, x, *args, **kwargs):
|
z, reg_log = self.regularization(z)
|
||||||
return x
|
if return_reg_log:
|
||||||
|
return z, reg_log
|
||||||
|
return z
|
||||||
|
|
||||||
|
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
||||||
|
if self.max_batch_size is None:
|
||||||
|
dec = self.post_quant_conv(z)
|
||||||
|
dec = self.decoder(dec, **decoder_kwargs)
|
||||||
|
else:
|
||||||
|
N = z.shape[0]
|
||||||
|
bs = self.max_batch_size
|
||||||
|
n_batches = int(math.ceil(N / bs))
|
||||||
|
dec = list()
|
||||||
|
for i_batch in range(n_batches):
|
||||||
|
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
||||||
|
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
||||||
|
dec.append(dec_batch)
|
||||||
|
dec = torch.cat(dec, 0)
|
||||||
|
|
||||||
|
return dec
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencoderKL(AutoencodingEngineLegacy):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
if "lossconfig" in kwargs:
|
||||||
|
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
||||||
|
super().__init__(
|
||||||
|
regularizer_config={
|
||||||
|
"target": (
|
||||||
|
"comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|||||||
@ -285,15 +285,14 @@ def attention_pytorch(q, k, v, heads, mask=None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
|
|
||||||
if exists(mask):
|
|
||||||
raise NotImplementedError
|
|
||||||
out = (
|
out = (
|
||||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
optimized_attention = attention_basic
|
optimized_attention = attention_basic
|
||||||
|
optimized_attention_masked = attention_basic
|
||||||
|
|
||||||
if model_management.xformers_enabled():
|
if model_management.xformers_enabled():
|
||||||
print("Using xformers cross attention")
|
print("Using xformers cross attention")
|
||||||
@ -309,6 +308,9 @@ else:
|
|||||||
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||||
optimized_attention = attention_sub_quad
|
optimized_attention = attention_sub_quad
|
||||||
|
|
||||||
|
if model_management.pytorch_attention_enabled():
|
||||||
|
optimized_attention_masked = attention_pytorch
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -334,7 +336,10 @@ class CrossAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
v = self.to_v(context)
|
v = self.to_v(context)
|
||||||
|
|
||||||
out = optimized_attention(q, k, v, self.heads, mask)
|
if mask is None:
|
||||||
|
out = optimized_attention(q, k, v, self.heads)
|
||||||
|
else:
|
||||||
|
out = optimized_attention_masked(q, k, v, self.heads, mask)
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -193,6 +193,52 @@ def slice_attention(q, k, v):
|
|||||||
|
|
||||||
return r1
|
return r1
|
||||||
|
|
||||||
|
def normal_attention(q, k, v):
|
||||||
|
# compute attention
|
||||||
|
b,c,h,w = q.shape
|
||||||
|
|
||||||
|
q = q.reshape(b,c,h*w)
|
||||||
|
q = q.permute(0,2,1) # b,hw,c
|
||||||
|
k = k.reshape(b,c,h*w) # b,c,hw
|
||||||
|
v = v.reshape(b,c,h*w)
|
||||||
|
|
||||||
|
r1 = slice_attention(q, k, v)
|
||||||
|
h_ = r1.reshape(b,c,h,w)
|
||||||
|
del r1
|
||||||
|
return h_
|
||||||
|
|
||||||
|
def xformers_attention(q, k, v):
|
||||||
|
# compute attention
|
||||||
|
B, C, H, W = q.shape
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||||
|
out = out.transpose(1, 2).reshape(B, C, H, W)
|
||||||
|
except NotImplementedError as e:
|
||||||
|
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def pytorch_attention(q, k, v):
|
||||||
|
# compute attention
|
||||||
|
B, C, H, W = q.shape
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||||
|
out = out.transpose(2, 3).reshape(B, C, H, W)
|
||||||
|
except model_management.OOM_EXCEPTION as e:
|
||||||
|
print("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||||
|
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
class AttnBlock(nn.Module):
|
||||||
def __init__(self, in_channels):
|
def __init__(self, in_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -220,6 +266,16 @@ class AttnBlock(nn.Module):
|
|||||||
stride=1,
|
stride=1,
|
||||||
padding=0)
|
padding=0)
|
||||||
|
|
||||||
|
if model_management.xformers_enabled_vae():
|
||||||
|
print("Using xformers attention in VAE")
|
||||||
|
self.optimized_attention = xformers_attention
|
||||||
|
elif model_management.pytorch_attention_enabled():
|
||||||
|
print("Using pytorch attention in VAE")
|
||||||
|
self.optimized_attention = pytorch_attention
|
||||||
|
else:
|
||||||
|
print("Using split attention in VAE")
|
||||||
|
self.optimized_attention = normal_attention
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
@ -227,149 +283,15 @@ class AttnBlock(nn.Module):
|
|||||||
k = self.k(h_)
|
k = self.k(h_)
|
||||||
v = self.v(h_)
|
v = self.v(h_)
|
||||||
|
|
||||||
# compute attention
|
h_ = self.optimized_attention(q, k, v)
|
||||||
b,c,h,w = q.shape
|
|
||||||
|
|
||||||
q = q.reshape(b,c,h*w)
|
|
||||||
q = q.permute(0,2,1) # b,hw,c
|
|
||||||
k = k.reshape(b,c,h*w) # b,c,hw
|
|
||||||
v = v.reshape(b,c,h*w)
|
|
||||||
|
|
||||||
r1 = slice_attention(q, k, v)
|
|
||||||
h_ = r1.reshape(b,c,h,w)
|
|
||||||
del r1
|
|
||||||
h_ = self.proj_out(h_)
|
h_ = self.proj_out(h_)
|
||||||
|
|
||||||
return x+h_
|
return x+h_
|
||||||
|
|
||||||
class MemoryEfficientAttnBlock(nn.Module):
|
|
||||||
"""
|
|
||||||
Uses xformers efficient implementation,
|
|
||||||
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
|
||||||
Note: this is a single-head self-attention operation
|
|
||||||
"""
|
|
||||||
#
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
self.norm = Normalize(in_channels)
|
|
||||||
self.q = comfy.ops.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.k = comfy.ops.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.v = comfy.ops.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.proj_out = comfy.ops.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.attention_op: Optional[Any] = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q = self.q(h_)
|
|
||||||
k = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
B, C, H, W = q.shape
|
|
||||||
q, k, v = map(
|
|
||||||
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
|
|
||||||
(q, k, v),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
|
||||||
out = out.transpose(1, 2).reshape(B, C, H, W)
|
|
||||||
except NotImplementedError as e:
|
|
||||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
|
||||||
|
|
||||||
out = self.proj_out(out)
|
|
||||||
return x+out
|
|
||||||
|
|
||||||
class MemoryEfficientAttnBlockPytorch(nn.Module):
|
|
||||||
def __init__(self, in_channels):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
self.norm = Normalize(in_channels)
|
|
||||||
self.q = comfy.ops.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.k = comfy.ops.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.v = comfy.ops.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.proj_out = comfy.ops.Conv2d(in_channels,
|
|
||||||
in_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=1,
|
|
||||||
padding=0)
|
|
||||||
self.attention_op: Optional[Any] = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
h_ = x
|
|
||||||
h_ = self.norm(h_)
|
|
||||||
q = self.q(h_)
|
|
||||||
k = self.k(h_)
|
|
||||||
v = self.v(h_)
|
|
||||||
|
|
||||||
# compute attention
|
|
||||||
B, C, H, W = q.shape
|
|
||||||
q, k, v = map(
|
|
||||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
|
||||||
(q, k, v),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
|
||||||
out = out.transpose(2, 3).reshape(B, C, H, W)
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
|
||||||
print("scaled_dot_product_attention OOMed: switched to slice attention")
|
|
||||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
|
||||||
|
|
||||||
out = self.proj_out(out)
|
|
||||||
return x+out
|
|
||||||
|
|
||||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||||
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
|
return AttnBlock(in_channels)
|
||||||
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
|
|
||||||
attn_type = "vanilla-xformers"
|
|
||||||
elif model_management.pytorch_attention_enabled() and attn_type == "vanilla":
|
|
||||||
attn_type = "vanilla-pytorch"
|
|
||||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
|
||||||
if attn_type == "vanilla":
|
|
||||||
assert attn_kwargs is None
|
|
||||||
return AttnBlock(in_channels)
|
|
||||||
elif attn_type == "vanilla-xformers":
|
|
||||||
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
|
||||||
return MemoryEfficientAttnBlock(in_channels)
|
|
||||||
elif attn_type == "vanilla-pytorch":
|
|
||||||
return MemoryEfficientAttnBlockPytorch(in_channels)
|
|
||||||
elif attn_type == "none":
|
|
||||||
return nn.Identity(in_channels)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
@ -619,7 +541,10 @@ class Decoder(nn.Module):
|
|||||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||||
attn_type="vanilla", **ignorekwargs):
|
conv_out_op=comfy.ops.Conv2d,
|
||||||
|
resnet_op=ResnetBlock,
|
||||||
|
attn_op=AttnBlock,
|
||||||
|
**ignorekwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if use_linear_attn: attn_type = "linear"
|
if use_linear_attn: attn_type = "linear"
|
||||||
self.ch = ch
|
self.ch = ch
|
||||||
@ -648,12 +573,12 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
# middle
|
# middle
|
||||||
self.mid = nn.Module()
|
self.mid = nn.Module()
|
||||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
self.mid.block_1 = resnet_op(in_channels=block_in,
|
||||||
out_channels=block_in,
|
out_channels=block_in,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout)
|
dropout=dropout)
|
||||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
self.mid.attn_1 = attn_op(block_in)
|
||||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
self.mid.block_2 = resnet_op(in_channels=block_in,
|
||||||
out_channels=block_in,
|
out_channels=block_in,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout)
|
dropout=dropout)
|
||||||
@ -665,13 +590,13 @@ class Decoder(nn.Module):
|
|||||||
attn = nn.ModuleList()
|
attn = nn.ModuleList()
|
||||||
block_out = ch*ch_mult[i_level]
|
block_out = ch*ch_mult[i_level]
|
||||||
for i_block in range(self.num_res_blocks+1):
|
for i_block in range(self.num_res_blocks+1):
|
||||||
block.append(ResnetBlock(in_channels=block_in,
|
block.append(resnet_op(in_channels=block_in,
|
||||||
out_channels=block_out,
|
out_channels=block_out,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout))
|
dropout=dropout))
|
||||||
block_in = block_out
|
block_in = block_out
|
||||||
if curr_res in attn_resolutions:
|
if curr_res in attn_resolutions:
|
||||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
attn.append(attn_op(block_in))
|
||||||
up = nn.Module()
|
up = nn.Module()
|
||||||
up.block = block
|
up.block = block
|
||||||
up.attn = attn
|
up.attn = attn
|
||||||
@ -682,13 +607,13 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
# end
|
# end
|
||||||
self.norm_out = Normalize(block_in)
|
self.norm_out = Normalize(block_in)
|
||||||
self.conv_out = comfy.ops.Conv2d(block_in,
|
self.conv_out = conv_out_op(block_in,
|
||||||
out_ch,
|
out_ch,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=1)
|
padding=1)
|
||||||
|
|
||||||
def forward(self, z):
|
def forward(self, z, **kwargs):
|
||||||
#assert z.shape[1:] == self.z_shape[1:]
|
#assert z.shape[1:] == self.z_shape[1:]
|
||||||
self.last_z_shape = z.shape
|
self.last_z_shape = z.shape
|
||||||
|
|
||||||
@ -699,16 +624,16 @@ class Decoder(nn.Module):
|
|||||||
h = self.conv_in(z)
|
h = self.conv_in(z)
|
||||||
|
|
||||||
# middle
|
# middle
|
||||||
h = self.mid.block_1(h, temb)
|
h = self.mid.block_1(h, temb, **kwargs)
|
||||||
h = self.mid.attn_1(h)
|
h = self.mid.attn_1(h, **kwargs)
|
||||||
h = self.mid.block_2(h, temb)
|
h = self.mid.block_2(h, temb, **kwargs)
|
||||||
|
|
||||||
# upsampling
|
# upsampling
|
||||||
for i_level in reversed(range(self.num_resolutions)):
|
for i_level in reversed(range(self.num_resolutions)):
|
||||||
for i_block in range(self.num_res_blocks+1):
|
for i_block in range(self.num_res_blocks+1):
|
||||||
h = self.up[i_level].block[i_block](h, temb)
|
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
||||||
if len(self.up[i_level].attn) > 0:
|
if len(self.up[i_level].attn) > 0:
|
||||||
h = self.up[i_level].attn[i_block](h)
|
h = self.up[i_level].attn[i_block](h, **kwargs)
|
||||||
if i_level != 0:
|
if i_level != 0:
|
||||||
h = self.up[i_level].upsample(h)
|
h = self.up[i_level].upsample(h)
|
||||||
|
|
||||||
@ -718,7 +643,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
h = self.norm_out(h)
|
h = self.norm_out(h)
|
||||||
h = nonlinearity(h)
|
h = nonlinearity(h)
|
||||||
h = self.conv_out(h)
|
h = self.conv_out(h, **kwargs)
|
||||||
if self.tanh_out:
|
if self.tanh_out:
|
||||||
h = torch.tanh(h)
|
h = torch.tanh(h)
|
||||||
return h
|
return h
|
||||||
|
|||||||
@ -26,6 +26,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||||
if self.adm_channels is None:
|
if self.adm_channels is None:
|
||||||
self.adm_channels = 0
|
self.adm_channels = 0
|
||||||
|
self.inpaint_model = False
|
||||||
print("model_type", model_type.name)
|
print("model_type", model_type.name)
|
||||||
print("adm", self.adm_channels)
|
print("adm", self.adm_channels)
|
||||||
|
|
||||||
@ -71,6 +72,38 @@ class BaseModel(torch.nn.Module):
|
|||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def cond_concat(self, **kwargs):
|
||||||
|
if self.inpaint_model:
|
||||||
|
concat_keys = ("mask", "masked_image")
|
||||||
|
cond_concat = []
|
||||||
|
denoise_mask = kwargs.get("denoise_mask", None)
|
||||||
|
latent_image = kwargs.get("latent_image", None)
|
||||||
|
noise = kwargs.get("noise", None)
|
||||||
|
device = kwargs["device"]
|
||||||
|
|
||||||
|
def blank_inpaint_image_like(latent_image):
|
||||||
|
blank_image = torch.ones_like(latent_image)
|
||||||
|
# these are the values for "zero" in pixel space translated to latent space
|
||||||
|
blank_image[:,0] *= 0.8223
|
||||||
|
blank_image[:,1] *= -0.6876
|
||||||
|
blank_image[:,2] *= 0.6364
|
||||||
|
blank_image[:,3] *= 0.1380
|
||||||
|
return blank_image
|
||||||
|
|
||||||
|
for ck in concat_keys:
|
||||||
|
if denoise_mask is not None:
|
||||||
|
if ck == "mask":
|
||||||
|
cond_concat.append(denoise_mask[:,:1].to(device))
|
||||||
|
elif ck == "masked_image":
|
||||||
|
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||||
|
else:
|
||||||
|
if ck == "mask":
|
||||||
|
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||||
|
elif ck == "masked_image":
|
||||||
|
cond_concat.append(blank_inpaint_image_like(noise))
|
||||||
|
return cond_concat
|
||||||
|
return None
|
||||||
|
|
||||||
def load_model_weights(self, sd, unet_prefix=""):
|
def load_model_weights(self, sd, unet_prefix=""):
|
||||||
to_load = {}
|
to_load = {}
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
@ -112,7 +145,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
|
||||||
|
|
||||||
def set_inpaint(self):
|
def set_inpaint(self):
|
||||||
self.concat_keys = ("mask", "masked_image")
|
self.inpaint_model = True
|
||||||
|
|
||||||
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
|
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
|
||||||
adm_inputs = []
|
adm_inputs = []
|
||||||
|
|||||||
@ -667,7 +667,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
#FP16 is just broken on these cards
|
#FP16 is just broken on these cards
|
||||||
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"]
|
nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
|
||||||
for x in nvidia_16_series:
|
for x in nvidia_16_series:
|
||||||
if x in props.name:
|
if x in props.name:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -98,6 +98,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
|||||||
samples = samples.cpu()
|
samples = samples.cpu()
|
||||||
|
|
||||||
cleanup_additional_models(models)
|
cleanup_additional_models(models)
|
||||||
|
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
@ -109,5 +110,6 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
|
|||||||
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
samples = samples.cpu()
|
samples = samples.cpu()
|
||||||
cleanup_additional_models(models)
|
cleanup_additional_models(models)
|
||||||
|
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|||||||
@ -14,8 +14,8 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
|
|||||||
|
|
||||||
#The main sampling function shared by all the samplers
|
#The main sampling function shared by all the samplers
|
||||||
#Returns predicted noise
|
#Returns predicted noise
|
||||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None):
|
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||||
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
|
def get_area_and_mult(cond, x_in, timestep_in):
|
||||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||||
strength = 1.0
|
strength = 1.0
|
||||||
if 'timestep_start' in cond[1]:
|
if 'timestep_start' in cond[1]:
|
||||||
@ -68,12 +68,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
|
|
||||||
conditionning = {}
|
conditionning = {}
|
||||||
conditionning['c_crossattn'] = cond[0]
|
conditionning['c_crossattn'] = cond[0]
|
||||||
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
|
||||||
cropped = []
|
if 'concat' in cond[1]:
|
||||||
for x in cond_concat_in:
|
cond_concat_in = cond[1]['concat']
|
||||||
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
if cond_concat_in is not None and len(cond_concat_in) > 0:
|
||||||
cropped.append(cr)
|
cropped = []
|
||||||
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
for x in cond_concat_in:
|
||||||
|
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
|
||||||
|
cropped.append(cr)
|
||||||
|
conditionning['c_concat'] = torch.cat(cropped, dim=1)
|
||||||
|
|
||||||
if adm_cond is not None:
|
if adm_cond is not None:
|
||||||
conditionning['c_adm'] = adm_cond
|
conditionning['c_adm'] = adm_cond
|
||||||
@ -173,7 +176,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
out['c_adm'] = torch.cat(c_adm)
|
out['c_adm'] = torch.cat(c_adm)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
|
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
||||||
out_cond = torch.zeros_like(x_in)
|
out_cond = torch.zeros_like(x_in)
|
||||||
out_count = torch.ones_like(x_in)/100000.0
|
out_count = torch.ones_like(x_in)/100000.0
|
||||||
|
|
||||||
@ -185,14 +188,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
|
|
||||||
to_run = []
|
to_run = []
|
||||||
for x in cond:
|
for x in cond:
|
||||||
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_run += [(p, COND)]
|
to_run += [(p, COND)]
|
||||||
if uncond is not None:
|
if uncond is not None:
|
||||||
for x in uncond:
|
for x in uncond:
|
||||||
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
|
p = get_area_and_mult(x, x_in, timestep)
|
||||||
if p is None:
|
if p is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -286,7 +289,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
|||||||
if math.isclose(cond_scale, 1.0):
|
if math.isclose(cond_scale, 1.0):
|
||||||
uncond = None
|
uncond = None
|
||||||
|
|
||||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
|
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
|
||||||
if "sampler_cfg_function" in model_options:
|
if "sampler_cfg_function" in model_options:
|
||||||
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
|
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
|
||||||
return model_options["sampler_cfg_function"](args)
|
return model_options["sampler_cfg_function"](args)
|
||||||
@ -307,8 +310,8 @@ class CFGNoisePredictor(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
self.alphas_cumprod = model.alphas_cumprod
|
self.alphas_cumprod = model.alphas_cumprod
|
||||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None):
|
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed)
|
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@ -316,11 +319,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
|||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inner_model = model
|
self.inner_model = model
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None):
|
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
latent_mask = 1. - denoise_mask
|
latent_mask = 1. - denoise_mask
|
||||||
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
|
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
|
||||||
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed)
|
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
out *= denoise_mask
|
out *= denoise_mask
|
||||||
|
|
||||||
@ -358,15 +361,6 @@ def sgm_scheduler(model, steps):
|
|||||||
sigs += [0.0]
|
sigs += [0.0]
|
||||||
return torch.FloatTensor(sigs)
|
return torch.FloatTensor(sigs)
|
||||||
|
|
||||||
def blank_inpaint_image_like(latent_image):
|
|
||||||
blank_image = torch.ones_like(latent_image)
|
|
||||||
# these are the values for "zero" in pixel space translated to latent space
|
|
||||||
blank_image[:,0] *= 0.8223
|
|
||||||
blank_image[:,1] *= -0.6876
|
|
||||||
blank_image[:,2] *= 0.6364
|
|
||||||
blank_image[:,3] *= 0.1380
|
|
||||||
return blank_image
|
|
||||||
|
|
||||||
def get_mask_aabb(masks):
|
def get_mask_aabb(masks):
|
||||||
if masks.numel() == 0:
|
if masks.numel() == 0:
|
||||||
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
|
||||||
@ -543,6 +537,20 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
|||||||
|
|
||||||
return conds
|
return conds
|
||||||
|
|
||||||
|
def encode_cond(model_function, key, conds, device, **kwargs):
|
||||||
|
for t in range(len(conds)):
|
||||||
|
x = conds[t]
|
||||||
|
params = x[1].copy()
|
||||||
|
params["device"] = device
|
||||||
|
for k in kwargs:
|
||||||
|
if k not in params:
|
||||||
|
params[k] = kwargs[k]
|
||||||
|
|
||||||
|
out = model_function(**params)
|
||||||
|
if out is not None:
|
||||||
|
x[1] = x[1].copy()
|
||||||
|
x[1][key] = out
|
||||||
|
return conds
|
||||||
|
|
||||||
class Sampler:
|
class Sampler:
|
||||||
def sample(self):
|
def sample(self):
|
||||||
@ -662,31 +670,19 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
|
||||||
|
if latent_image is not None:
|
||||||
|
latent_image = model.process_latent_in(latent_image)
|
||||||
|
|
||||||
if model.is_adm():
|
if model.is_adm():
|
||||||
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
|
||||||
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
||||||
|
|
||||||
if latent_image is not None:
|
if hasattr(model, 'cond_concat'):
|
||||||
latent_image = model.process_latent_in(latent_image)
|
positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||||
|
negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||||
|
|
||||||
cond_concat = None
|
|
||||||
if hasattr(model, 'concat_keys'): #inpaint
|
|
||||||
cond_concat = []
|
|
||||||
for ck in model.concat_keys:
|
|
||||||
if denoise_mask is not None:
|
|
||||||
if ck == "mask":
|
|
||||||
cond_concat.append(denoise_mask[:,:1])
|
|
||||||
elif ck == "masked_image":
|
|
||||||
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
|
|
||||||
else:
|
|
||||||
if ck == "mask":
|
|
||||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
|
||||||
elif ck == "masked_image":
|
|
||||||
cond_concat.append(blank_inpaint_image_like(noise))
|
|
||||||
extra_args["cond_concat"] = cond_concat
|
|
||||||
|
|
||||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||||
return model.process_latent_out(samples.to(torch.float32))
|
return model.process_latent_out(samples.to(torch.float32))
|
||||||
|
|
||||||
@ -743,7 +739,7 @@ class KSampler:
|
|||||||
sigmas = None
|
sigmas = None
|
||||||
|
|
||||||
discard_penultimate_sigma = False
|
discard_penultimate_sigma = False
|
||||||
if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
|
if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
|
||||||
steps += 1
|
steps += 1
|
||||||
discard_penultimate_sigma = True
|
discard_penultimate_sigma = True
|
||||||
|
|
||||||
|
|||||||
48
comfy/sd.py
48
comfy/sd.py
@ -4,7 +4,7 @@ import math
|
|||||||
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
from .ldm.util import instantiate_from_config
|
from .ldm.util import instantiate_from_config
|
||||||
from .ldm.models.autoencoder import AutoencoderKL
|
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@ -140,21 +140,24 @@ class CLIP:
|
|||||||
return self.patcher.get_key_patches()
|
return self.patcher.get_key_patches()
|
||||||
|
|
||||||
class VAE:
|
class VAE:
|
||||||
def __init__(self, ckpt_path=None, device=None, config=None):
|
def __init__(self, sd=None, device=None, config=None):
|
||||||
|
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||||
|
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
#default SD1.x/SD2.x VAE parameters
|
#default SD1.x/SD2.x VAE parameters
|
||||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||||
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
|
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||||
else:
|
else:
|
||||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||||
self.first_stage_model = self.first_stage_model.eval()
|
self.first_stage_model = self.first_stage_model.eval()
|
||||||
if ckpt_path is not None:
|
|
||||||
sd = comfy.utils.load_torch_file(ckpt_path)
|
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
||||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
if len(m) > 0:
|
||||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
print("Missing VAE keys", m)
|
||||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
|
||||||
if len(m) > 0:
|
if len(u) > 0:
|
||||||
print("Missing VAE keys", m)
|
print("Leftover VAE keys", u)
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = model_management.vae_device()
|
device = model_management.vae_device()
|
||||||
@ -183,7 +186,7 @@ class VAE:
|
|||||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float()
|
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
||||||
@ -229,7 +232,7 @@ class VAE:
|
|||||||
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
|
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
|
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
|
||||||
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float()
|
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
|
||||||
|
|
||||||
except model_management.OOM_EXCEPTION as e:
|
except model_management.OOM_EXCEPTION as e:
|
||||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||||
@ -375,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
model.load_model_weights(state_dict, "model.diffusion_model.")
|
model.load_model_weights(state_dict, "model.diffusion_model.")
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
w = WeightsLoader()
|
vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
|
||||||
vae = VAE(config=vae_config)
|
vae = VAE(sd=vae_sd, config=vae_config)
|
||||||
w.first_stage_model = vae.first_stage_model
|
|
||||||
load_model_weights(w, state_dict)
|
|
||||||
|
|
||||||
if output_clip:
|
if output_clip:
|
||||||
w = WeightsLoader()
|
w = WeightsLoader()
|
||||||
@ -427,18 +428,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
model.load_model_weights(sd, "model.diffusion_model.")
|
model.load_model_weights(sd, "model.diffusion_model.")
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
vae = VAE()
|
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
|
||||||
w = WeightsLoader()
|
vae = VAE(sd=vae_sd)
|
||||||
w.first_stage_model = vae.first_stage_model
|
|
||||||
load_model_weights(w, sd)
|
|
||||||
|
|
||||||
if output_clip:
|
if output_clip:
|
||||||
w = WeightsLoader()
|
w = WeightsLoader()
|
||||||
clip_target = model_config.clip_target()
|
clip_target = model_config.clip_target()
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
if clip_target is not None:
|
||||||
w.cond_stage_model = clip.cond_stage_model
|
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||||
sd = model_config.process_clip_state_dict(sd)
|
w.cond_stage_model = clip.cond_stage_model
|
||||||
load_model_weights(w, sd)
|
sd = model_config.process_clip_state_dict(sd)
|
||||||
|
load_model_weights(w, sd)
|
||||||
|
|
||||||
left_over = sd.keys()
|
left_over = sd.keys()
|
||||||
if len(left_over) > 0:
|
if len(left_over) > 0:
|
||||||
|
|||||||
@ -47,12 +47,17 @@ def state_dict_key_replace(state_dict, keys_to_replace):
|
|||||||
state_dict[keys_to_replace[x]] = state_dict.pop(x)
|
state_dict[keys_to_replace[x]] = state_dict.pop(x)
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def state_dict_prefix_replace(state_dict, replace_prefix):
|
def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
|
||||||
|
if filter_keys:
|
||||||
|
out = {}
|
||||||
|
else:
|
||||||
|
out = state_dict
|
||||||
for rp in replace_prefix:
|
for rp in replace_prefix:
|
||||||
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
|
replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
|
||||||
for x in replace:
|
for x in replace:
|
||||||
state_dict[x[1]] = state_dict.pop(x[0])
|
w = state_dict.pop(x[0])
|
||||||
return state_dict
|
out[x[1]] = w
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def transformers_convert(sd, prefix_from, prefix_to, number):
|
def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||||
|
|||||||
@ -61,7 +61,53 @@ class FreeU:
|
|||||||
m.set_model_output_block_patch(output_block_patch)
|
m.set_model_output_block_patch(output_block_patch)
|
||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
|
class FreeU_V2:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def patch(self, model, b1, b2, s1, s2):
|
||||||
|
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||||
|
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||||
|
on_cpu_devices = {}
|
||||||
|
|
||||||
|
def output_block_patch(h, hsp, transformer_options):
|
||||||
|
scale = scale_dict.get(h.shape[1], None)
|
||||||
|
if scale is not None:
|
||||||
|
hidden_mean = h.mean(1).unsqueeze(1)
|
||||||
|
B = hidden_mean.shape[0]
|
||||||
|
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
||||||
|
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
|
||||||
|
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
|
||||||
|
|
||||||
|
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
|
||||||
|
|
||||||
|
if hsp.device not in on_cpu_devices:
|
||||||
|
try:
|
||||||
|
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
|
||||||
|
except:
|
||||||
|
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
|
||||||
|
on_cpu_devices[hsp.device] = True
|
||||||
|
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
||||||
|
else:
|
||||||
|
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
|
||||||
|
|
||||||
|
return h, hsp
|
||||||
|
|
||||||
|
m = model.clone()
|
||||||
|
m.set_model_output_block_patch(output_block_patch)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"FreeU": FreeU,
|
"FreeU": FreeU,
|
||||||
|
"FreeU_V2": FreeU_V2,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -19,6 +19,7 @@ def load_hypernetwork_patch(path, strength):
|
|||||||
"tanh": torch.nn.Tanh,
|
"tanh": torch.nn.Tanh,
|
||||||
"sigmoid": torch.nn.Sigmoid,
|
"sigmoid": torch.nn.Sigmoid,
|
||||||
"softsign": torch.nn.Softsign,
|
"softsign": torch.nn.Softsign,
|
||||||
|
"mish": torch.nn.Mish,
|
||||||
}
|
}
|
||||||
|
|
||||||
if activation_func not in valid_activation:
|
if activation_func not in valid_activation:
|
||||||
@ -42,7 +43,8 @@ def load_hypernetwork_patch(path, strength):
|
|||||||
linears = list(map(lambda a: a[:-len(".weight")], linears))
|
linears = list(map(lambda a: a[:-len(".weight")], linears))
|
||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for i in range(len(linears)):
|
i = 0
|
||||||
|
while i < len(linears):
|
||||||
lin_name = linears[i]
|
lin_name = linears[i]
|
||||||
last_layer = (i == (len(linears) - 1))
|
last_layer = (i == (len(linears) - 1))
|
||||||
penultimate_layer = (i == (len(linears) - 2))
|
penultimate_layer = (i == (len(linears) - 2))
|
||||||
@ -56,10 +58,17 @@ def load_hypernetwork_patch(path, strength):
|
|||||||
if (not last_layer) or (activate_output):
|
if (not last_layer) or (activate_output):
|
||||||
layers.append(valid_activation[activation_func]())
|
layers.append(valid_activation[activation_func]())
|
||||||
if is_layer_norm:
|
if is_layer_norm:
|
||||||
layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
|
i += 1
|
||||||
|
ln_name = linears[i]
|
||||||
|
ln_weight = attn_weights['{}.weight'.format(ln_name)]
|
||||||
|
ln_bias = attn_weights['{}.bias'.format(ln_name)]
|
||||||
|
ln = torch.nn.LayerNorm(ln_weight.shape[0])
|
||||||
|
ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
|
||||||
|
layers.append(ln)
|
||||||
if use_dropout:
|
if use_dropout:
|
||||||
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
|
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
|
||||||
layers.append(torch.nn.Dropout(p=0.3))
|
layers.append(torch.nn.Dropout(p=0.3))
|
||||||
|
i += 1
|
||||||
|
|
||||||
output.append(torch.nn.Sequential(*layers))
|
output.append(torch.nn.Sequential(*layers))
|
||||||
out[dim] = torch.nn.ModuleList(output)
|
out[dim] = torch.nn.ModuleList(output)
|
||||||
|
|||||||
83
comfy_extras/nodes_hypertile.py
Normal file
83
comfy_extras/nodes_hypertile.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
#Taken from: https://github.com/tfernd/HyperTile/
|
||||||
|
|
||||||
|
import math
|
||||||
|
from einops import rearrange
|
||||||
|
import random
|
||||||
|
|
||||||
|
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
|
||||||
|
min_value = min(min_value, value)
|
||||||
|
|
||||||
|
# All big divisors of value (inclusive)
|
||||||
|
divisors = [i for i in range(min_value, value + 1) if value % i == 0]
|
||||||
|
|
||||||
|
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
|
||||||
|
|
||||||
|
random.seed(counter)
|
||||||
|
idx = random.randint(0, len(ns) - 1)
|
||||||
|
|
||||||
|
return ns[idx]
|
||||||
|
|
||||||
|
class HyperTile:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
|
||||||
|
"swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
|
||||||
|
"max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
|
||||||
|
"scale_depth": ("BOOLEAN", {"default": False}),
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "patch"
|
||||||
|
|
||||||
|
CATEGORY = "_for_testing"
|
||||||
|
|
||||||
|
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
|
||||||
|
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||||
|
|
||||||
|
apply_to = set()
|
||||||
|
temp = model_channels
|
||||||
|
for x in range(max_depth + 1):
|
||||||
|
apply_to.add(temp)
|
||||||
|
temp *= 2
|
||||||
|
|
||||||
|
latent_tile_size = max(32, tile_size) // 8
|
||||||
|
self.temp = None
|
||||||
|
self.counter = 1
|
||||||
|
|
||||||
|
def hypertile_in(q, k, v, extra_options):
|
||||||
|
if q.shape[-1] in apply_to:
|
||||||
|
shape = extra_options["original_shape"]
|
||||||
|
aspect_ratio = shape[-1] / shape[-2]
|
||||||
|
|
||||||
|
hw = q.size(1)
|
||||||
|
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
||||||
|
|
||||||
|
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
|
||||||
|
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
|
||||||
|
self.counter += 1
|
||||||
|
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
|
||||||
|
self.counter += 1
|
||||||
|
|
||||||
|
if nh * nw > 1:
|
||||||
|
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||||
|
self.temp = (nh, nw, h, w)
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
return q, k, v
|
||||||
|
def hypertile_out(out, extra_options):
|
||||||
|
if self.temp is not None:
|
||||||
|
nh, nw, h, w = self.temp
|
||||||
|
self.temp = None
|
||||||
|
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
|
||||||
|
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
m = model.clone()
|
||||||
|
m.set_model_attn1_patch(hypertile_in)
|
||||||
|
m.set_model_attn1_output_patch(hypertile_out)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
"HyperTile": HyperTile,
|
||||||
|
}
|
||||||
8
nodes.py
8
nodes.py
@ -584,7 +584,8 @@ class VAELoader:
|
|||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name):
|
||||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||||
vae = comfy.sd.VAE(ckpt_path=vae_path)
|
sd = comfy.utils.load_torch_file(vae_path)
|
||||||
|
vae = comfy.sd.VAE(sd=sd)
|
||||||
return (vae,)
|
return (vae,)
|
||||||
|
|
||||||
class ControlNetLoader:
|
class ControlNetLoader:
|
||||||
@ -1660,7 +1661,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"KSampler": "KSampler",
|
"KSampler": "KSampler",
|
||||||
"KSamplerAdvanced": "KSampler (Advanced)",
|
"KSamplerAdvanced": "KSampler (Advanced)",
|
||||||
# Loaders
|
# Loaders
|
||||||
"CheckpointLoader": "Load Checkpoint (With Config)",
|
"CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
|
||||||
"CheckpointLoaderSimple": "Load Checkpoint",
|
"CheckpointLoaderSimple": "Load Checkpoint",
|
||||||
"VAELoader": "Load VAE",
|
"VAELoader": "Load VAE",
|
||||||
"LoraLoader": "Load LoRA",
|
"LoraLoader": "Load LoRA",
|
||||||
@ -1795,7 +1796,8 @@ def init_custom_nodes():
|
|||||||
"nodes_clip_sdxl.py",
|
"nodes_clip_sdxl.py",
|
||||||
"nodes_canny.py",
|
"nodes_canny.py",
|
||||||
"nodes_freelunch.py",
|
"nodes_freelunch.py",
|
||||||
"nodes_custom_sampler.py"
|
"nodes_custom_sampler.py",
|
||||||
|
"nodes_hypertile.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
for node_file in extras_files:
|
for node_file in extras_files:
|
||||||
|
|||||||
@ -47,7 +47,7 @@
|
|||||||
" !git pull\n",
|
" !git pull\n",
|
||||||
"\n",
|
"\n",
|
||||||
"!echo -= Install dependencies =-\n",
|
"!echo -= Install dependencies =-\n",
|
||||||
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
|
"!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
6
tests-ui/package-lock.json
generated
6
tests-ui/package-lock.json
generated
@ -1816,9 +1816,9 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/@babel/traverse": {
|
"node_modules/@babel/traverse": {
|
||||||
"version": "7.23.0",
|
"version": "7.23.2",
|
||||||
"resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.23.0.tgz",
|
"resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.23.2.tgz",
|
||||||
"integrity": "sha512-t/QaEvyIoIkwzpiZ7aoSKK8kObQYeF7T2v+dazAYCb8SXtp58zEVkWW7zAnju8FNKNdr4ScAOEDmMItbyOmEYw==",
|
"integrity": "sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw==",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@babel/code-frame": "^7.22.13",
|
"@babel/code-frame": "^7.22.13",
|
||||||
|
|||||||
@ -4,7 +4,8 @@
|
|||||||
"description": "UI tests",
|
"description": "UI tests",
|
||||||
"main": "index.js",
|
"main": "index.js",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"test": "jest"
|
"test": "jest",
|
||||||
|
"test:generate": "node setup.js"
|
||||||
},
|
},
|
||||||
"repository": {
|
"repository": {
|
||||||
"type": "git",
|
"type": "git",
|
||||||
|
|||||||
87
tests-ui/setup.js
Normal file
87
tests-ui/setup.js
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
const { spawn } = require("child_process");
|
||||||
|
const { resolve } = require("path");
|
||||||
|
const { existsSync, mkdirSync, writeFileSync } = require("fs");
|
||||||
|
const http = require("http");
|
||||||
|
|
||||||
|
async function setup() {
|
||||||
|
// Wait up to 30s for it to start
|
||||||
|
let success = false;
|
||||||
|
let child;
|
||||||
|
for (let i = 0; i < 30; i++) {
|
||||||
|
try {
|
||||||
|
await new Promise((res, rej) => {
|
||||||
|
http
|
||||||
|
.get("http://127.0.0.1:8188/object_info", (resp) => {
|
||||||
|
let data = "";
|
||||||
|
resp.on("data", (chunk) => {
|
||||||
|
data += chunk;
|
||||||
|
});
|
||||||
|
resp.on("end", () => {
|
||||||
|
// Modify the response data to add some checkpoints
|
||||||
|
const objectInfo = JSON.parse(data);
|
||||||
|
objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];
|
||||||
|
|
||||||
|
data = JSON.stringify(objectInfo, undefined, "\t");
|
||||||
|
|
||||||
|
const outDir = resolve("./data");
|
||||||
|
if (!existsSync(outDir)) {
|
||||||
|
mkdirSync(outDir);
|
||||||
|
}
|
||||||
|
|
||||||
|
const outPath = resolve(outDir, "object_info.json");
|
||||||
|
console.log(`Writing ${Object.keys(objectInfo).length} nodes to ${outPath}`);
|
||||||
|
writeFileSync(outPath, data, {
|
||||||
|
encoding: "utf8",
|
||||||
|
});
|
||||||
|
res();
|
||||||
|
});
|
||||||
|
})
|
||||||
|
.on("error", rej);
|
||||||
|
});
|
||||||
|
success = true;
|
||||||
|
break;
|
||||||
|
} catch (error) {
|
||||||
|
console.log(i + "/30", error);
|
||||||
|
if (i === 0) {
|
||||||
|
// Start the server on first iteration if it fails to connect
|
||||||
|
console.log("Starting ComfyUI server...");
|
||||||
|
|
||||||
|
let python = resolve("../../python_embeded/python.exe");
|
||||||
|
let args;
|
||||||
|
let cwd;
|
||||||
|
if (existsSync(python)) {
|
||||||
|
args = ["-s", "ComfyUI/main.py"];
|
||||||
|
cwd = "../..";
|
||||||
|
} else {
|
||||||
|
python = "python";
|
||||||
|
args = ["main.py"];
|
||||||
|
cwd = "..";
|
||||||
|
}
|
||||||
|
args.push("--cpu");
|
||||||
|
console.log(python, ...args);
|
||||||
|
child = spawn(python, args, { cwd });
|
||||||
|
child.on("error", (err) => {
|
||||||
|
console.log(`Server error (${err})`);
|
||||||
|
i = 30;
|
||||||
|
});
|
||||||
|
child.on("exit", (code) => {
|
||||||
|
if (!success) {
|
||||||
|
console.log(`Server exited (${code})`);
|
||||||
|
i = 30;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
await new Promise((r) => {
|
||||||
|
setTimeout(r, 1000);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
child?.kill();
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
throw new Error("Waiting for server failed...");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setup();
|
||||||
@ -5,6 +5,61 @@ function setNodeMode(node, mode) {
|
|||||||
node.graph.change();
|
node.graph.change();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function addNodesToGroup(group, nodes=[]) {
|
||||||
|
var x1, y1, x2, y2;
|
||||||
|
var nx1, ny1, nx2, ny2;
|
||||||
|
var node;
|
||||||
|
|
||||||
|
x1 = y1 = x2 = y2 = -1;
|
||||||
|
nx1 = ny1 = nx2 = ny2 = -1;
|
||||||
|
|
||||||
|
for (var n of [group._nodes, nodes]) {
|
||||||
|
for (var i in n) {
|
||||||
|
node = n[i]
|
||||||
|
|
||||||
|
nx1 = node.pos[0]
|
||||||
|
ny1 = node.pos[1]
|
||||||
|
nx2 = node.pos[0] + node.size[0]
|
||||||
|
ny2 = node.pos[1] + node.size[1]
|
||||||
|
|
||||||
|
if (node.type != "Reroute") {
|
||||||
|
ny1 -= LiteGraph.NODE_TITLE_HEIGHT;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (node.flags?.collapsed) {
|
||||||
|
ny2 = ny1 + LiteGraph.NODE_TITLE_HEIGHT;
|
||||||
|
|
||||||
|
if (node?._collapsed_width) {
|
||||||
|
nx2 = nx1 + Math.round(node._collapsed_width);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (x1 == -1 || nx1 < x1) {
|
||||||
|
x1 = nx1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (y1 == -1 || ny1 < y1) {
|
||||||
|
y1 = ny1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (x2 == -1 || nx2 > x2) {
|
||||||
|
x2 = nx2;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (y2 == -1 || ny2 > y2) {
|
||||||
|
y2 = ny2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var padding = 10;
|
||||||
|
|
||||||
|
y1 = y1 - Math.round(group.font_size * 1.4);
|
||||||
|
|
||||||
|
group.pos = [x1 - padding, y1 - padding];
|
||||||
|
group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2];
|
||||||
|
}
|
||||||
|
|
||||||
app.registerExtension({
|
app.registerExtension({
|
||||||
name: "Comfy.GroupOptions",
|
name: "Comfy.GroupOptions",
|
||||||
setup() {
|
setup() {
|
||||||
@ -14,6 +69,17 @@ app.registerExtension({
|
|||||||
const options = orig.apply(this, arguments);
|
const options = orig.apply(this, arguments);
|
||||||
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
|
const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
|
||||||
if (!group) {
|
if (!group) {
|
||||||
|
options.push({
|
||||||
|
content: "Add Group For Selected Nodes",
|
||||||
|
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
|
||||||
|
callback: () => {
|
||||||
|
var group = new LiteGraph.LGraphGroup();
|
||||||
|
addNodesToGroup(group, this.selected_nodes)
|
||||||
|
app.canvas.graph.add(group);
|
||||||
|
this.graph.change();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
return options;
|
return options;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,6 +87,15 @@ app.registerExtension({
|
|||||||
group.recomputeInsideNodes();
|
group.recomputeInsideNodes();
|
||||||
const nodesInGroup = group._nodes;
|
const nodesInGroup = group._nodes;
|
||||||
|
|
||||||
|
options.push({
|
||||||
|
content: "Add Selected Nodes To Group",
|
||||||
|
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
|
||||||
|
callback: () => {
|
||||||
|
addNodesToGroup(group, this.selected_nodes)
|
||||||
|
this.graph.change();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// No nodes in group, return default options
|
// No nodes in group, return default options
|
||||||
if (nodesInGroup.length === 0) {
|
if (nodesInGroup.length === 0) {
|
||||||
return options;
|
return options;
|
||||||
@ -38,6 +113,14 @@ app.registerExtension({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
options.push({
|
||||||
|
content: "Fit Group To Nodes",
|
||||||
|
callback: () => {
|
||||||
|
addNodesToGroup(group)
|
||||||
|
this.graph.change();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
options.push({
|
options.push({
|
||||||
content: "Select Nodes",
|
content: "Select Nodes",
|
||||||
callback: () => {
|
callback: () => {
|
||||||
|
|||||||
@ -22,6 +22,15 @@ class ManageTemplates extends ComfyDialog {
|
|||||||
super();
|
super();
|
||||||
this.element.classList.add("comfy-manage-templates");
|
this.element.classList.add("comfy-manage-templates");
|
||||||
this.templates = this.load();
|
this.templates = this.load();
|
||||||
|
|
||||||
|
this.importInput = $el("input", {
|
||||||
|
type: "file",
|
||||||
|
accept: ".json",
|
||||||
|
multiple: true,
|
||||||
|
style: {display: "none"},
|
||||||
|
parent: document.body,
|
||||||
|
onchange: () => this.importAll(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
createButtons() {
|
createButtons() {
|
||||||
@ -34,6 +43,22 @@ class ManageTemplates extends ComfyDialog {
|
|||||||
onclick: () => this.save(),
|
onclick: () => this.save(),
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
btns.unshift(
|
||||||
|
$el("button", {
|
||||||
|
type: "button",
|
||||||
|
textContent: "Export",
|
||||||
|
onclick: () => this.exportAll(),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
btns.unshift(
|
||||||
|
$el("button", {
|
||||||
|
type: "button",
|
||||||
|
textContent: "Import",
|
||||||
|
onclick: () => {
|
||||||
|
this.importInput.click();
|
||||||
|
},
|
||||||
|
})
|
||||||
|
);
|
||||||
return btns;
|
return btns;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,6 +94,52 @@ class ManageTemplates extends ComfyDialog {
|
|||||||
localStorage.setItem(id, JSON.stringify(this.templates));
|
localStorage.setItem(id, JSON.stringify(this.templates));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async importAll() {
|
||||||
|
for (const file of this.importInput.files) {
|
||||||
|
if (file.type === "application/json" || file.name.endsWith(".json")) {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = async () => {
|
||||||
|
var importFile = JSON.parse(reader.result);
|
||||||
|
if (importFile && importFile?.templates) {
|
||||||
|
for (const template of importFile.templates) {
|
||||||
|
if (template?.name && template?.data) {
|
||||||
|
this.templates.push(template);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.store();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
await reader.readAsText(file);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.importInput.value = null;
|
||||||
|
|
||||||
|
this.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
exportAll() {
|
||||||
|
if (this.templates.length == 0) {
|
||||||
|
alert("No templates to export.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const json = JSON.stringify({templates: this.templates}, null, 2); // convert the data to a JSON string
|
||||||
|
const blob = new Blob([json], {type: "application/json"});
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = $el("a", {
|
||||||
|
href: url,
|
||||||
|
download: "node_templates.json",
|
||||||
|
style: {display: "none"},
|
||||||
|
parent: document.body,
|
||||||
|
});
|
||||||
|
a.click();
|
||||||
|
setTimeout(function () {
|
||||||
|
a.remove();
|
||||||
|
window.URL.revokeObjectURL(url);
|
||||||
|
}, 0);
|
||||||
|
}
|
||||||
|
|
||||||
show() {
|
show() {
|
||||||
// Show list of template names + delete button
|
// Show list of template names + delete button
|
||||||
super.show(
|
super.show(
|
||||||
@ -97,19 +168,48 @@ class ManageTemplates extends ComfyDialog {
|
|||||||
}),
|
}),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
$el("button", {
|
$el(
|
||||||
textContent: "Delete",
|
"div",
|
||||||
style: {
|
{},
|
||||||
fontSize: "12px",
|
[
|
||||||
color: "red",
|
$el("button", {
|
||||||
fontWeight: "normal",
|
textContent: "Export",
|
||||||
},
|
style: {
|
||||||
onclick: (e) => {
|
fontSize: "12px",
|
||||||
nameInput.value = "";
|
fontWeight: "normal",
|
||||||
e.target.style.display = "none";
|
},
|
||||||
e.target.previousElementSibling.style.display = "none";
|
onclick: (e) => {
|
||||||
},
|
const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
|
||||||
}),
|
const blob = new Blob([json], {type: "application/json"});
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = $el("a", {
|
||||||
|
href: url,
|
||||||
|
download: (nameInput.value || t.name) + ".json",
|
||||||
|
style: {display: "none"},
|
||||||
|
parent: document.body,
|
||||||
|
});
|
||||||
|
a.click();
|
||||||
|
setTimeout(function () {
|
||||||
|
a.remove();
|
||||||
|
window.URL.revokeObjectURL(url);
|
||||||
|
}, 0);
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
$el("button", {
|
||||||
|
textContent: "Delete",
|
||||||
|
style: {
|
||||||
|
fontSize: "12px",
|
||||||
|
color: "red",
|
||||||
|
fontWeight: "normal",
|
||||||
|
},
|
||||||
|
onclick: (e) => {
|
||||||
|
nameInput.value = "";
|
||||||
|
e.target.parentElement.style.display = "none";
|
||||||
|
e.target.parentElement.previousElementSibling.style.display = "none";
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
]
|
||||||
|
),
|
||||||
];
|
];
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
@ -164,19 +264,17 @@ app.registerExtension({
|
|||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
if (subItems.length) {
|
subItems.push(null, {
|
||||||
subItems.push(null, {
|
content: "Manage",
|
||||||
content: "Manage",
|
callback: () => manage.show(),
|
||||||
callback: () => manage.show(),
|
});
|
||||||
});
|
|
||||||
|
|
||||||
options.push({
|
options.push({
|
||||||
content: "Node Templates",
|
content: "Node Templates",
|
||||||
submenu: {
|
submenu: {
|
||||||
options: subItems,
|
options: subItems,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
|
||||||
|
|
||||||
return options;
|
return options;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -3796,7 +3796,7 @@
|
|||||||
out = out || new Float32Array(4);
|
out = out || new Float32Array(4);
|
||||||
out[0] = this.pos[0] - 4;
|
out[0] = this.pos[0] - 4;
|
||||||
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
|
out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
|
||||||
out[2] = this.size[0] + 4;
|
out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
|
||||||
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
|
out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
|
||||||
|
|
||||||
if (this.onBounding) {
|
if (this.onBounding) {
|
||||||
|
|||||||
@ -492,7 +492,7 @@ export class ComfyApp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (this.imgs && this.imgs.length) {
|
if (this.imgs && this.imgs.length) {
|
||||||
const canvas = graph.list_of_graphcanvas[0];
|
const canvas = app.graph.list_of_graphcanvas[0];
|
||||||
const mouse = canvas.graph_mouse;
|
const mouse = canvas.graph_mouse;
|
||||||
if (!canvas.pointer_is_down && this.pointerDown) {
|
if (!canvas.pointer_is_down && this.pointerDown) {
|
||||||
if (mouse[0] === this.pointerDown.pos[0] && mouse[1] === this.pointerDown.pos[1]) {
|
if (mouse[0] === this.pointerDown.pos[0] && mouse[1] === this.pointerDown.pos[1]) {
|
||||||
@ -1430,6 +1430,43 @@ export class ComfyApp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loadTemplateData(templateData) {
|
||||||
|
if (!templateData?.templates) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const old = localStorage.getItem("litegrapheditor_clipboard");
|
||||||
|
|
||||||
|
var maxY, nodeBottom, node;
|
||||||
|
|
||||||
|
for (const template of templateData.templates) {
|
||||||
|
if (!template?.data) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
localStorage.setItem("litegrapheditor_clipboard", template.data);
|
||||||
|
app.canvas.pasteFromClipboard();
|
||||||
|
|
||||||
|
// Move mouse position down to paste the next template below
|
||||||
|
|
||||||
|
maxY = false;
|
||||||
|
|
||||||
|
for (const i in app.canvas.selected_nodes) {
|
||||||
|
node = app.canvas.selected_nodes[i];
|
||||||
|
|
||||||
|
nodeBottom = node.pos[1] + node.size[1];
|
||||||
|
|
||||||
|
if (maxY === false || nodeBottom > maxY) {
|
||||||
|
maxY = nodeBottom;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
app.canvas.graph_mouse[1] = maxY + 50;
|
||||||
|
}
|
||||||
|
|
||||||
|
localStorage.setItem("litegrapheditor_clipboard", old);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Populates the graph with the specified workflow data
|
* Populates the graph with the specified workflow data
|
||||||
* @param {*} graphData A serialized graph object
|
* @param {*} graphData A serialized graph object
|
||||||
@ -1775,7 +1812,12 @@ export class ComfyApp {
|
|||||||
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
|
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
|
||||||
const reader = new FileReader();
|
const reader = new FileReader();
|
||||||
reader.onload = async () => {
|
reader.onload = async () => {
|
||||||
await this.loadGraphData(JSON.parse(reader.result));
|
var jsonContent = JSON.parse(reader.result);
|
||||||
|
if (jsonContent?.templates) {
|
||||||
|
this.loadTemplateData(jsonContent);
|
||||||
|
} else {
|
||||||
|
await this.loadGraphData(jsonContent);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
reader.readAsText(file);
|
reader.readAsText(file);
|
||||||
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
|
} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user