Merge remote-tracking branch 'origin/master' into group-nodes

commit 288a1c6242
Author: pythongosssss
Date: 2023-10-21 17:41:10 +01:00
27 changed files with 907 additions and 463 deletions

View File

@@ -10,8 +10,17 @@ jobs:
       - uses: actions/setup-node@v3
         with:
          node-version: 18
+      - uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install requirements
+        run: |
+          python -m pip install --upgrade pip
+          pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+          pip install -r requirements.txt
       - name: Run Tests
         run: |
-          npm install
+          npm ci
+          npm run test:generate
           npm test
         working-directory: ./tests-ui

1
.gitignore vendored
View File

@@ -14,3 +14,4 @@ venv/
 /web/extensions/*
 !/web/extensions/logging.js.example
 !/web/extensions/core/
+/tests-ui/data/object_info.json

View File

@@ -92,8 +92,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
-    else:
+    elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
+    else:
+        return None
     clip = ClipVisionModel(json_config)
     m, u = clip.load_sd(sd)
     if len(m) > 0:
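With this change load_clipvision_from_sd() probes for the ViT-L layer key explicitly and returns None for checkpoints it does not recognize, instead of silently assuming ViT-L. A minimal caller-side sketch of handling that None (hypothetical snippet, not part of the commit; assumes a loaded state dict sd):

import comfy.clip_vision

clip_vision = comfy.clip_vision.load_clipvision_from_sd(sd)
if clip_vision is None:
    # unknown vision transformer layout: fail loudly rather than loading a wrong config
    raise RuntimeError("Unrecognized CLIP vision checkpoint")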

View File

@@ -31,6 +31,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
     vae = None
     if output_vae:
-        vae = comfy.sd.VAE(ckpt_path=vae_path)
+        sd = comfy.utils.load_torch_file(vae_path)
+        vae = comfy.sd.VAE(sd=sd)
     return (unet, clip, vae)

View File

@@ -713,8 +713,8 @@ class UniPC:
         method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
         atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
     ):
-        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
-        t_T = self.noise_schedule.T if t_start is None else t_start
+        # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+        # t_T = self.noise_schedule.T if t_start is None else t_start
         device = x.device
         steps = len(timesteps) - 1
         if method == 'multistep':
@@ -769,8 +769,8 @@ class UniPC:
                 callback(step_index, model_prev_list[-1], x, steps)
         else:
             raise NotImplementedError()
-        if denoise_to_zero:
-            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+        # if denoise_to_zero:
+        #     x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
         return x
@@ -833,21 +833,33 @@ def expand_dims(v, dims):
     return v[(...,) + (None,)*(dims - 1)]

+class SigmaConvert:
+    schedule = ""
+    def marginal_log_mean_coeff(self, sigma):
+        return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
+
+    def marginal_alpha(self, t):
+        return torch.exp(self.marginal_log_mean_coeff(t))
+
+    def marginal_std(self, t):
+        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+    def marginal_lambda(self, t):
+        """
+        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+        """
+        log_mean_coeff = self.marginal_log_mean_coeff(t)
+        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+        return log_mean_coeff - log_std
+
 def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
-    to_zero = False
+    timesteps = sigmas.clone()
     if sigmas[-1] == 0:
-        timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
-        to_zero = True
+        timesteps = sigmas[:]
+        timesteps[-1] = 0.001
     else:
         timesteps = sigmas.clone()

-    alphas_cumprod = model.inner_model.alphas_cumprod
-    for s in range(timesteps.shape[0]):
-        timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
-
-    ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+    ns = SigmaConvert()

     if image is not None:
         img = image * ns.marginal_alpha(timesteps[0])
@@ -859,16 +871,10 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
     else:
         img = noise

-    if to_zero:
-        timesteps[-1] = (1 / len(alphas_cumprod))
-
-    device = noise.device
-
     model_type = "noise"

     model_fn = model_wrapper(
-        model.predict_eps_discrete_timestep,
+        model.predict_eps_sigma,
         ns,
         model_type=model_type,
         guidance_type="uncond",
@@ -878,6 +884,5 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
     order = min(3, len(timesteps) - 1)
     uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
     x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
-    if not to_zero:
-        x /= ns.marginal_alpha(timesteps[-1])
+    x /= ns.marginal_alpha(timesteps[-1])
     return x
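The new SigmaConvert schedule works directly in sigma space: log alpha(sigma) = -0.5 * log(sigma^2 + 1), so alpha = 1/sqrt(sigma^2 + 1) and std = sigma/sqrt(sigma^2 + 1), which keeps alpha^2 + std^2 = 1 and std/alpha = sigma. A small standalone sanity check of those identities (not part of the commit; assumes the SigmaConvert class above is in scope):

import torch

ns = SigmaConvert()
sigma = torch.tensor([0.1, 1.0, 14.6])
alpha = ns.marginal_alpha(sigma)
std = ns.marginal_std(sigma)
# variance-preserving parameterization: alpha^2 + std^2 == 1 and std/alpha recovers sigma
assert torch.allclose(alpha ** 2 + std ** 2, torch.ones_like(sigma), atol=1e-6)
assert torch.allclose(std / alpha, sigma, atol=1e-4)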

View File

@@ -97,6 +97,10 @@ class DiscreteSchedule(nn.Module):
         input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
         return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)

+    def predict_eps_sigma(self, input, sigma, **kwargs):
+        input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
+        return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
+
 class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
     """A wrapper for discrete schedule DDPM models that output eps (the predicted
     noise)."""

View File

@ -2,67 +2,66 @@ import torch
# import pytorch_lightning as pl # import pytorch_lightning as pl
import torch.nn.functional as F import torch.nn.functional as F
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy.ldm.util import instantiate_from_config from comfy.ldm.util import instantiate_from_config
from comfy.ldm.modules.ema import LitEma from comfy.ldm.modules.ema import LitEma
# class AutoencoderKL(pl.LightningModule): class DiagonalGaussianRegularizer(torch.nn.Module):
class AutoencoderKL(torch.nn.Module): def __init__(self, sample: bool = True):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
super().__init__() super().__init__()
self.learn_logvar = learn_logvar self.sample = sample
self.image_key = image_key
self.encoder = Encoder(**ddconfig) def get_trainable_parameters(self) -> Any:
self.decoder = Decoder(**ddconfig) yield from ()
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"] def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) log = dict()
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) posterior = DiagonalGaussianDistribution(z)
self.embed_dim = embed_dim if self.sample:
if colorize_nlabels is not None: z = posterior.sample()
assert type(colorize_nlabels)==int else:
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log["kl_loss"] = kl_loss
return z, log
class AbstractAutoencoder(torch.nn.Module):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""
def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
**kwargs,
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema: if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay) self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None: def get_input(self, batch) -> Any:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) raise NotImplementedError()
def init_from_ckpt(self, path, ignore_keys=list()): def on_train_batch_end(self, *args, **kwargs):
if path.lower().endswith(".safetensors"): # for EMA computation
import safetensors.torch if self.use_ema:
sd = safetensors.torch.load_file(path, device="cpu") self.model_ema(self)
else:
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager @contextmanager
def ema_scope(self, context=None): def ema_scope(self, context=None):
@ -70,154 +69,159 @@ class AutoencoderKL(torch.nn.Module):
self.model_ema.store(self.parameters()) self.model_ema.store(self.parameters())
self.model_ema.copy_to(self) self.model_ema.copy_to(self)
if context is not None: if context is not None:
print(f"{context}: Switched to EMA weights") logpy.info(f"{context}: Switched to EMA weights")
try: try:
yield None yield None
finally: finally:
if self.use_ema: if self.use_ema:
self.model_ema.restore(self.parameters()) self.model_ema.restore(self.parameters())
if context is not None: if context is not None:
print(f"{context}: Restored training weights") logpy.info(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs): def encode(self, *args, **kwargs) -> torch.Tensor:
if self.use_ema: raise NotImplementedError("encode()-method of abstract base class called")
self.model_ema(self)
def encode(self, x): def decode(self, *args, **kwargs) -> torch.Tensor:
h = self.encoder(x) raise NotImplementedError("decode()-method of abstract base class called")
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z): def instantiate_optimizer_from_config(self, params, lr, cfg):
z = self.post_quant_conv(z) logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
dec = self.decoder(z) return get_obj_from_str(cfg["target"])(
return dec params, lr=lr, **cfg.get("params", dict())
)
def forward(self, input, sample_posterior=True): def configure_optimizers(self) -> Any:
posterior = self.encode(input) raise NotImplementedError()
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx): class AutoencodingEngine(AbstractAutoencoder):
inputs = self.get_input(batch, self.image_key) """
reconstructions, posterior = self(inputs) Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
(we also restore them explicitly as special cases for legacy reasons).
Regularizations such as KL or VQ are moved to the regularizer class.
"""
if optimizer_idx == 0: def __init__(
# train encoder+decoder+logvar self,
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, *args,
last_layer=self.get_last_layer(), split="train") encoder_config: Dict,
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) decoder_config: Dict,
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) regularizer_config: Dict,
return aeloss **kwargs,
):
super().__init__(*args, **kwargs)
if optimizer_idx == 1: self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
# train the discriminator self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, self.regularization: AbstractRegularizer = instantiate_from_config(
last_layer=self.get_last_layer(), split="train") regularizer_config
)
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self): def get_last_layer(self):
return self.decoder.conv_out.weight return self.decoder.get_last_layer()
@torch.no_grad() def encode(
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): self,
log = dict() x: torch.Tensor,
x = self.get_input(batch, self.image_key) return_reg_log: bool = False,
x = x.to(self.device) unregularized: bool = False,
if not only_inputs: ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
xrec, posterior = self(x) z = self.encoder(x)
if x.shape[1] > 3: if unregularized:
# colorize with random projection return z, dict()
assert xrec.shape[1] > 3 z, reg_log = self.regularization(z)
x = self.to_rgb(x) if return_reg_log:
xrec = self.to_rgb(xrec) return z, reg_log
log["samples"] = self.decode(torch.randn_like(posterior.sample())) return z
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def to_rgb(self, x): def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
assert self.image_key == "segmentation" x = self.decoder(z, **kwargs)
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x return x
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs): class AutoencodingEngineLegacy(AutoencodingEngine):
return x def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
super().__init__(
encoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
"params": ddconfig,
},
decoder_config={
"target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
def decode(self, x, *args, **kwargs): def get_autoencoder_params(self) -> list:
return x params = super().get_autoencoder_params()
return params
def quantize(self, x, *args, **kwargs): def encode(
if self.vq_interface: self, x: torch.Tensor, return_reg_log: bool = False
return x, None, [None, None, None] ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
return x if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
def forward(self, x, *args, **kwargs): z, reg_log = self.regularization(z)
return x if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
return dec
class AutoencoderKL(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if "lossconfig" in kwargs:
kwargs["loss_config"] = kwargs.pop("lossconfig")
super().__init__(
regularizer_config={
"target": (
"comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
)
},
**kwargs,
)

View File

@@ -285,15 +285,14 @@ def attention_pytorch(q, k, v, heads, mask=None):
     )

     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-
-    if exists(mask):
-        raise NotImplementedError
     out = (
         out.transpose(1, 2).reshape(b, -1, heads * dim_head)
     )
     return out

 optimized_attention = attention_basic
+optimized_attention_masked = attention_basic

 if model_management.xformers_enabled():
     print("Using xformers cross attention")
@@ -309,6 +308,9 @@ else:
         print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad

+if model_management.pytorch_attention_enabled():
+    optimized_attention_masked = attention_pytorch
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
@@ -334,7 +336,10 @@ class CrossAttention(nn.Module):
         else:
             v = self.to_v(context)

-        out = optimized_attention(q, k, v, self.heads, mask)
+        if mask is None:
+            out = optimized_attention(q, k, v, self.heads)
+        else:
+            out = optimized_attention_masked(q, k, v, self.heads, mask)
         return self.to_out(out)
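The removed NotImplementedError means attention_pytorch now forwards the mask straight to torch.nn.functional.scaled_dot_product_attention, and CrossAttention routes masked calls through optimized_attention_masked while the unmasked fast path stays unchanged. A standalone illustration of the underlying PyTorch call (not repo code):

import torch

b, heads, tokens, dim_head = 1, 8, 16, 64
q = torch.randn(b, heads, tokens, dim_head)
k = torch.randn(b, heads, tokens, dim_head)
v = torch.randn(b, heads, tokens, dim_head)
mask = torch.ones(tokens, tokens, dtype=torch.bool).tril()  # True = attend; broadcasts over batch and heads
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.shape)  # torch.Size([1, 8, 16, 64])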

View File

@ -193,6 +193,52 @@ def slice_attention(q, k, v):
return r1 return r1
def normal_attention(q, k, v):
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
return h_
def xformers_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
)
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out
def pytorch_attention(q, k, v):
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out
class AttnBlock(nn.Module): class AttnBlock(nn.Module):
def __init__(self, in_channels): def __init__(self, in_channels):
super().__init__() super().__init__()
@ -220,6 +266,16 @@ class AttnBlock(nn.Module):
stride=1, stride=1,
padding=0) padding=0)
if model_management.xformers_enabled_vae():
print("Using xformers attention in VAE")
self.optimized_attention = xformers_attention
elif model_management.pytorch_attention_enabled():
print("Using pytorch attention in VAE")
self.optimized_attention = pytorch_attention
else:
print("Using split attention in VAE")
self.optimized_attention = normal_attention
def forward(self, x): def forward(self, x):
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)
@ -227,149 +283,15 @@ class AttnBlock(nn.Module):
k = self.k(h_) k = self.k(h_)
v = self.v(h_) v = self.v(h_)
# compute attention h_ = self.optimized_attention(q, k, v)
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
v = v.reshape(b,c,h*w)
r1 = slice_attention(q, k, v)
h_ = r1.reshape(b,c,h,w)
del r1
h_ = self.proj_out(h_) h_ = self.proj_out(h_)
return x+h_ return x+h_
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses xformers efficient implementation,
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.attention_op: Optional[Any] = None
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
(q, k, v),
)
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
out = out.transpose(1, 2).reshape(B, C, H, W)
except NotImplementedError as e:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = self.proj_out(out)
return x+out
class MemoryEfficientAttnBlockPytorch(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.attention_op: Optional[Any] = None
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
)
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
except model_management.OOM_EXCEPTION as e:
print("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
out = self.proj_out(out)
return x+out
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' return AttnBlock(in_channels)
if model_management.xformers_enabled_vae() and attn_type == "vanilla":
attn_type = "vanilla-xformers"
elif model_management.pytorch_attention_enabled() and attn_type == "vanilla":
attn_type = "vanilla-pytorch"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif attn_type == "vanilla-pytorch":
return MemoryEfficientAttnBlockPytorch(in_channels)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
raise NotImplementedError()
class Model(nn.Module): class Model(nn.Module):
@ -619,7 +541,10 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs): conv_out_op=comfy.ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
**ignorekwargs):
super().__init__() super().__init__()
if use_linear_attn: attn_type = "linear" if use_linear_attn: attn_type = "linear"
self.ch = ch self.ch = ch
@ -648,12 +573,12 @@ class Decoder(nn.Module):
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, self.mid.block_1 = resnet_op(in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.attn_1 = attn_op(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, self.mid.block_2 = resnet_op(in_channels=block_in,
out_channels=block_in, out_channels=block_in,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout) dropout=dropout)
@ -665,13 +590,13 @@ class Decoder(nn.Module):
attn = nn.ModuleList() attn = nn.ModuleList()
block_out = ch*ch_mult[i_level] block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1): for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in, block.append(resnet_op(in_channels=block_in,
out_channels=block_out, out_channels=block_out,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout)) dropout=dropout))
block_in = block_out block_in = block_out
if curr_res in attn_resolutions: if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type)) attn.append(attn_op(block_in))
up = nn.Module() up = nn.Module()
up.block = block up.block = block
up.attn = attn up.attn = attn
@ -682,13 +607,13 @@ class Decoder(nn.Module):
# end # end
self.norm_out = Normalize(block_in) self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in, self.conv_out = conv_out_op(block_in,
out_ch, out_ch,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
def forward(self, z): def forward(self, z, **kwargs):
#assert z.shape[1:] == self.z_shape[1:] #assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape self.last_z_shape = z.shape
@ -699,16 +624,16 @@ class Decoder(nn.Module):
h = self.conv_in(z) h = self.conv_in(z)
# middle # middle
h = self.mid.block_1(h, temb) h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h) h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb) h = self.mid.block_2(h, temb, **kwargs)
# upsampling # upsampling
for i_level in reversed(range(self.num_resolutions)): for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1): for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb) h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0: if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h) h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0: if i_level != 0:
h = self.up[i_level].upsample(h) h = self.up[i_level].upsample(h)
@ -718,7 +643,7 @@ class Decoder(nn.Module):
h = self.norm_out(h) h = self.norm_out(h)
h = nonlinearity(h) h = nonlinearity(h)
h = self.conv_out(h) h = self.conv_out(h, **kwargs)
if self.tanh_out: if self.tanh_out:
h = torch.tanh(h) h = torch.tanh(h)
return h return h
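The three VAE attention helpers added above (normal_attention, xformers_attention, pytorch_attention) all treat the whole feature map as a single-head sequence of H*W tokens; only the memory layout differs, which is why AttnBlock can pick one of them at construction time. A shape-bookkeeping sketch of the pytorch path (standalone, not repo code):

import torch

B, C, H, W = 1, 512, 32, 32
q = torch.randn(B, C, H, W)
k = torch.randn(B, C, H, W)
v = torch.randn(B, C, H, W)
# (B, C, H, W) -> (B, 1, H*W, C): one head, H*W tokens, C-dim embeddings
q_, k_, v_ = (t.view(B, 1, C, -1).transpose(2, 3).contiguous() for t in (q, k, v))
out = torch.nn.functional.scaled_dot_product_attention(q_, k_, v_)
out = out.transpose(2, 3).reshape(B, C, H, W)  # back to feature-map layout
print(out.shape)  # torch.Size([1, 512, 32, 32])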

View File

@@ -26,6 +26,7 @@ class BaseModel(torch.nn.Module):
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
+        self.inpaint_model = False
         print("model_type", model_type.name)
         print("adm", self.adm_channels)

@@ -71,6 +72,38 @@ class BaseModel(torch.nn.Module):
     def encode_adm(self, **kwargs):
         return None

+    def cond_concat(self, **kwargs):
+        if self.inpaint_model:
+            concat_keys = ("mask", "masked_image")
+            cond_concat = []
+            denoise_mask = kwargs.get("denoise_mask", None)
+            latent_image = kwargs.get("latent_image", None)
+            noise = kwargs.get("noise", None)
+            device = kwargs["device"]
+
+            def blank_inpaint_image_like(latent_image):
+                blank_image = torch.ones_like(latent_image)
+                # these are the values for "zero" in pixel space translated to latent space
+                blank_image[:,0] *= 0.8223
+                blank_image[:,1] *= -0.6876
+                blank_image[:,2] *= 0.6364
+                blank_image[:,3] *= 0.1380
+                return blank_image
+
+            for ck in concat_keys:
+                if denoise_mask is not None:
+                    if ck == "mask":
+                        cond_concat.append(denoise_mask[:,:1].to(device))
+                    elif ck == "masked_image":
+                        cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
+                else:
+                    if ck == "mask":
+                        cond_concat.append(torch.ones_like(noise)[:,:1])
+                    elif ck == "masked_image":
+                        cond_concat.append(blank_inpaint_image_like(noise))
+            return cond_concat
+        return None
+
     def load_model_weights(self, sd, unet_prefix=""):
         to_load = {}
         keys = list(sd.keys())
@@ -112,7 +145,7 @@ class BaseModel(torch.nn.Module):
         return {**unet_state_dict, **vae_state_dict, **clip_state_dict}

     def set_inpaint(self):
-        self.concat_keys = ("mask", "masked_image")
+        self.inpaint_model = True

 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
     adm_inputs = []
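cond_concat produces the extra channels an inpainting UNet expects next to the 4 latent channels: a 1-channel mask and a 4-channel masked-image latent (or their "blank" equivalents when no denoise mask is given). A hedged sketch of how those tensors end up stacked into the model input (the 4 + 1 + 4 = 9 channel layout matches SD inpaint checkpoints; illustrative code, not from the repo):

import torch

latent = torch.randn(1, 4, 64, 64)               # x_t
mask = torch.ones(1, 1, 64, 64)                   # 1 = repaint everything
masked_image_latent = torch.randn(1, 4, 64, 64)   # VAE-encoded masked input image
unet_input = torch.cat([latent, mask, masked_image_latent], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 64, 64]) -> matches in_channels=9 of inpaint UNets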

View File

@@ -667,7 +667,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
         return False

     #FP16 is just broken on these cards
-    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"]
+    nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
     for x in nvidia_16_series:
         if x in props.name:
             return False

View File

@@ -98,6 +98,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     samples = samples.cpu()

     cleanup_additional_models(models)
+    cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
     return samples

 def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
@@ -109,5 +110,6 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
     samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
     samples = samples.cpu()
     cleanup_additional_models(models)
+    cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
     return samples

View File

@ -14,8 +14,8 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
#The main sampling function shared by all the samplers #The main sampling function shared by all the samplers
#Returns predicted noise #Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None): def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): def get_area_and_mult(cond, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0 strength = 1.0
if 'timestep_start' in cond[1]: if 'timestep_start' in cond[1]:
@ -68,12 +68,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
conditionning = {} conditionning = {}
conditionning['c_crossattn'] = cond[0] conditionning['c_crossattn'] = cond[0]
if cond_concat_in is not None and len(cond_concat_in) > 0:
cropped = [] if 'concat' in cond[1]:
for x in cond_concat_in: cond_concat_in = cond[1]['concat']
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] if cond_concat_in is not None and len(cond_concat_in) > 0:
cropped.append(cr) cropped = []
conditionning['c_concat'] = torch.cat(cropped, dim=1) for x in cond_concat_in:
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
if adm_cond is not None: if adm_cond is not None:
conditionning['c_adm'] = adm_cond conditionning['c_adm'] = adm_cond
@ -173,7 +176,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
out['c_adm'] = torch.cat(c_adm) out['c_adm'] = torch.cat(c_adm)
return out return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options): def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
out_cond = torch.zeros_like(x_in) out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0 out_count = torch.ones_like(x_in)/100000.0
@ -185,14 +188,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
to_run = [] to_run = []
for x in cond: for x in cond:
p = get_area_and_mult(x, x_in, cond_concat_in, timestep) p = get_area_and_mult(x, x_in, timestep)
if p is None: if p is None:
continue continue
to_run += [(p, COND)] to_run += [(p, COND)]
if uncond is not None: if uncond is not None:
for x in uncond: for x in uncond:
p = get_area_and_mult(x, x_in, cond_concat_in, timestep) p = get_area_and_mult(x, x_in, timestep)
if p is None: if p is None:
continue continue
@ -286,7 +289,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if math.isclose(cond_scale, 1.0): if math.isclose(cond_scale, 1.0):
uncond = None uncond = None
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
if "sampler_cfg_function" in model_options: if "sampler_cfg_function" in model_options:
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep} args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
return model_options["sampler_cfg_function"](args) return model_options["sampler_cfg_function"](args)
@ -307,8 +310,8 @@ class CFGNoisePredictor(torch.nn.Module):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None): def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed) out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
return out return out
@ -316,11 +319,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None): def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
if denoise_mask is not None: if denoise_mask is not None:
latent_mask = 1. - denoise_mask latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed) out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
if denoise_mask is not None: if denoise_mask is not None:
out *= denoise_mask out *= denoise_mask
@ -358,15 +361,6 @@ def sgm_scheduler(model, steps):
sigs += [0.0] sigs += [0.0]
return torch.FloatTensor(sigs) return torch.FloatTensor(sigs)
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
blank_image[:,3] *= 0.1380
return blank_image
def get_mask_aabb(masks): def get_mask_aabb(masks):
if masks.numel() == 0: if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int) return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
@ -543,6 +537,20 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
return conds return conds
def encode_cond(model_function, key, conds, device, **kwargs):
for t in range(len(conds)):
x = conds[t]
params = x[1].copy()
params["device"] = device
for k in kwargs:
if k not in params:
params[k] = kwargs[k]
out = model_function(**params)
if out is not None:
x[1] = x[1].copy()
x[1][key] = out
return conds
class Sampler: class Sampler:
def sample(self): def sample(self):
@ -662,31 +670,19 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)
if model.is_adm(): if model.is_adm():
positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive") positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive")
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
if latent_image is not None: if hasattr(model, 'cond_concat'):
latent_image = model.process_latent_in(latent_image) positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
cond_concat = None
if hasattr(model, 'concat_keys'): #inpaint
cond_concat = []
for ck in model.concat_keys:
if denoise_mask is not None:
if ck == "mask":
cond_concat.append(denoise_mask[:,:1])
elif ck == "masked_image":
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
else:
if ck == "mask":
cond_concat.append(torch.ones_like(noise)[:,:1])
elif ck == "masked_image":
cond_concat.append(blank_inpaint_image_like(noise))
extra_args["cond_concat"] = cond_concat
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32)) return model.process_latent_out(samples.to(torch.float32))
@ -743,7 +739,7 @@ class KSampler:
sigmas = None sigmas = None
discard_penultimate_sigma = False discard_penultimate_sigma = False
if self.sampler in ['dpm_2', 'dpm_2_ancestral']: if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']:
steps += 1 steps += 1
discard_penultimate_sigma = True discard_penultimate_sigma = True
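The new encode_cond helper is what wires cond_concat into sampling: each conditioning entry is a [tensor, options] pair, and whatever the model callback returns gets stored under the requested key (here "concat") so get_area_and_mult can crop it and feed it as c_concat. A small standalone sketch of that bookkeeping (the fake callback stands in for model.cond_concat; assumes encode_cond from the hunk above is importable):

import torch

cond = [[torch.randn(1, 77, 768), {"strength": 1.0}]]

def fake_cond_concat(**kwargs):  # stand-in for BaseModel.cond_concat
    return [torch.ones(1, 1, 64, 64)]

cond = encode_cond(fake_cond_concat, "concat", cond, device="cpu",
                   noise=torch.randn(1, 4, 64, 64), latent_image=None, denoise_mask=None)
print(sorted(cond[0][1].keys()))  # ['concat', 'strength']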

View File

@ -4,7 +4,7 @@ import math
from comfy import model_management from comfy import model_management
from .ldm.util import instantiate_from_config from .ldm.util import instantiate_from_config
from .ldm.models.autoencoder import AutoencoderKL from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
import yaml import yaml
import comfy.utils import comfy.utils
@ -140,21 +140,24 @@ class CLIP:
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
class VAE: class VAE:
def __init__(self, ckpt_path=None, device=None, config=None): def __init__(self, sd=None, device=None, config=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
if config is None: if config is None:
#default SD1.x/SD2.x VAE parameters #default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else: else:
self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval() self.first_stage_model = self.first_stage_model.eval()
if ckpt_path is not None:
sd = comfy.utils.load_torch_file(ckpt_path) m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if len(m) > 0:
sd = diffusers_convert.convert_vae_state_dict(sd) print("Missing VAE keys", m)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0: if len(u) > 0:
print("Missing VAE keys", m) print("Leftover VAE keys", u)
if device is None: if device is None:
device = model_management.vae_device() device = model_management.vae_device()
@ -183,7 +186,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float() encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@ -229,7 +232,7 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float() samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
except model_management.OOM_EXCEPTION as e: except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@ -375,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model.load_model_weights(state_dict, "model.diffusion_model.") model.load_model_weights(state_dict, "model.diffusion_model.")
if output_vae: if output_vae:
w = WeightsLoader() vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
vae = VAE(config=vae_config) vae = VAE(sd=vae_sd, config=vae_config)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, state_dict)
if output_clip: if output_clip:
w = WeightsLoader() w = WeightsLoader()
@ -427,18 +428,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model.load_model_weights(sd, "model.diffusion_model.") model.load_model_weights(sd, "model.diffusion_model.")
if output_vae: if output_vae:
vae = VAE() vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
w = WeightsLoader() vae = VAE(sd=vae_sd)
w.first_stage_model = vae.first_stage_model
load_model_weights(w, sd)
if output_clip: if output_clip:
w = WeightsLoader() w = WeightsLoader()
clip_target = model_config.clip_target() clip_target = model_config.clip_target()
clip = CLIP(clip_target, embedding_directory=embedding_directory) if clip_target is not None:
w.cond_stage_model = clip.cond_stage_model clip = CLIP(clip_target, embedding_directory=embedding_directory)
sd = model_config.process_clip_state_dict(sd) w.cond_stage_model = clip.cond_stage_model
load_model_weights(w, sd) sd = model_config.process_clip_state_dict(sd)
load_model_weights(w, sd)
left_over = sd.keys() left_over = sd.keys()
if len(left_over) > 0: if len(left_over) > 0:
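The VAE class now takes a plain state dict instead of a checkpoint path, so checkpoint loaders first slice the VAE weights out of the full state dict by prefix. The new call pattern, as a hedged sketch using names from the hunks above (the checkpoint path is a placeholder):

import comfy.sd
import comfy.utils

sd = comfy.utils.load_torch_file("checkpoints/model.safetensors")  # placeholder path
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae = comfy.sd.VAE(sd=vae_sd)  # VAE(ckpt_path=...) is gone; diffusers-format dicts are converted internally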

View File

@@ -47,12 +47,17 @@ def state_dict_key_replace(state_dict, keys_to_replace):
         state_dict[keys_to_replace[x]] = state_dict.pop(x)
     return state_dict

-def state_dict_prefix_replace(state_dict, replace_prefix):
+def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
+    if filter_keys:
+        out = {}
+    else:
+        out = state_dict
     for rp in replace_prefix:
         replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
         for x in replace:
-            state_dict[x[1]] = state_dict.pop(x[0])
-    return state_dict
+            w = state_dict.pop(x[0])
+            out[x[1]] = w
+    return out

 def transformers_convert(sd, prefix_from, prefix_to, number):
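The filter_keys flag turns state_dict_prefix_replace from an in-place rename into a filtered copy: only keys matching one of the prefixes survive. A quick standalone check of both modes (assumes the function above is importable):

sd = {"first_stage_model.encoder.w": 1, "model.diffusion_model.w": 2}

only_vae = state_dict_prefix_replace(dict(sd), {"first_stage_model.": ""}, filter_keys=True)
print(only_vae)          # {'encoder.w': 1} -- non-matching keys are dropped

renamed = state_dict_prefix_replace(dict(sd), {"first_stage_model.": ""})
print(sorted(renamed))   # ['encoder.w', 'model.diffusion_model.w'] -- non-matching keys kept in place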

View File

@@ -61,7 +61,53 @@ class FreeU:
        m.set_model_output_block_patch(output_block_patch)
        return (m, )
class FreeU_V2:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
def output_block_patch(h, hsp, transformer_options):
scale = scale_dict.get(h.shape[1], None)
if scale is not None:
hidden_mean = h.mean(1).unsqueeze(1)
B = hidden_mean.shape[0]
hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
if hsp.device not in on_cpu_devices:
try:
hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
except:
print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
on_cpu_devices[hsp.device] = True
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
else:
hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
return h, hsp
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
return (m, )
NODE_CLASS_MAPPINGS = {
    "FreeU": FreeU,
    "FreeU_V2": FreeU_V2,
}

View File

@ -19,6 +19,7 @@ def load_hypernetwork_patch(path, strength):
"tanh": torch.nn.Tanh, "tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid, "sigmoid": torch.nn.Sigmoid,
"softsign": torch.nn.Softsign, "softsign": torch.nn.Softsign,
"mish": torch.nn.Mish,
} }
if activation_func not in valid_activation: if activation_func not in valid_activation:
@ -42,7 +43,8 @@ def load_hypernetwork_patch(path, strength):
        linears = list(map(lambda a: a[:-len(".weight")], linears))
        layers = []

-       for i in range(len(linears)):
+       i = 0
+       while i < len(linears):
            lin_name = linears[i]
            last_layer = (i == (len(linears) - 1))
            penultimate_layer = (i == (len(linears) - 2))
@ -56,10 +58,17 @@ def load_hypernetwork_patch(path, strength):
                if (not last_layer) or (activate_output):
                    layers.append(valid_activation[activation_func]())
            if is_layer_norm:
-               layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
+               i += 1
+               ln_name = linears[i]
+               ln_weight = attn_weights['{}.weight'.format(ln_name)]
+               ln_bias = attn_weights['{}.bias'.format(ln_name)]
+               ln = torch.nn.LayerNorm(ln_weight.shape[0])
+               ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
+               layers.append(ln)
            if use_dropout:
                if (not last_layer) and (not penultimate_layer or last_layer_dropout):
                    layers.append(torch.nn.Dropout(p=0.3))
+           i += 1
        output.append(torch.nn.Sequential(*layers))
    out[dim] = torch.nn.ModuleList(output)
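The layer-norm branch now restores trained parameters rather than appending a freshly initialized LayerNorm; in isolation the load pattern looks like this (the 320-dim tensors are stand-ins for values read from the hypernetwork file):

import torch

ln_weight = torch.ones(320)   # placeholder for attn_weights['<name>.weight']
ln_bias = torch.zeros(320)    # placeholder for attn_weights['<name>.bias']

ln = torch.nn.LayerNorm(ln_weight.shape[0])
ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})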

View File

@ -0,0 +1,83 @@
#Taken from: https://github.com/tfernd/HyperTile/

import math
from einops import rearrange
import random

def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
    min_value = min(min_value, value)

    # All big divisors of value (inclusive)
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]

    ns = [value // i for i in divisors[:max_options]]  # has at least 1 element

    random.seed(counter)
    idx = random.randint(0, len(ns) - 1)

    return ns[idx]
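A quick sanity check of the helper above (inputs chosen only for illustration): with a 64-token side and a minimum tile of 32, the candidate divisors are 32 and 64, so the call returns either 2 or 1 tiles per side depending on the seeded draw.

print(random_divisor(64, 32, max_options=2, counter=0))  # 1 or 2, deterministic for a given counter
print(random_divisor(64, 32, max_options=2, counter=1))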
class HyperTile:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
                              "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
                              "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
                              "scale_depth": ("BOOLEAN", {"default": False}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"

    def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
        model_channels = model.model.model_config.unet_config["model_channels"]

        apply_to = set()
        temp = model_channels
        for x in range(max_depth + 1):
            apply_to.add(temp)
            temp *= 2

        latent_tile_size = max(32, tile_size) // 8
        self.temp = None
        self.counter = 1

        def hypertile_in(q, k, v, extra_options):
            if q.shape[-1] in apply_to:
                shape = extra_options["original_shape"]
                aspect_ratio = shape[-1] / shape[-2]

                hw = q.size(1)
                h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))

                factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
                nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
                self.counter += 1
                nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
                self.counter += 1

                if nh * nw > 1:
                    q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
                    self.temp = (nh, nw, h, w)
                return q, k, v

            return q, k, v

        def hypertile_out(out, extra_options):
            if self.temp is not None:
                nh, nw, h, w = self.temp
                self.temp = None
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
            return out

        m = model.clone()
        m.set_model_attn1_patch(hypertile_in)
        m.set_model_attn1_output_patch(hypertile_out)
        return (m, )

NODE_CLASS_MAPPINGS = {
    "HyperTile": HyperTile,
}
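The attention patch above relies on einops to regroup tokens into spatial tiles and back; the shape round trip can be checked standalone on a dummy tensor (sizes are illustrative):

import torch
from einops import rearrange

# A 64x64 token grid split into 2x2 tiles, 320 channels, batch 2.
b, c = 2, 320
H, W, nh, nw = 64, 64, 2, 2
q = torch.randn(b, H * W, c)

tiled = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=H // nh, w=W // nw, nh=nh, nw=nw)
print(tiled.shape)          # torch.Size([8, 1024, 320])

out = rearrange(tiled, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=H // nh, w=W // nw)
print(torch.equal(q, out))  # True: the tiling is a pure permutation of tokens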

View File

@ -584,7 +584,8 @@ class VAELoader:
    #TODO: scale factor?
    def load_vae(self, vae_name):
        vae_path = folder_paths.get_full_path("vae", vae_name)
-       vae = comfy.sd.VAE(ckpt_path=vae_path)
+       sd = comfy.utils.load_torch_file(vae_path)
+       vae = comfy.sd.VAE(sd=sd)
        return (vae,)
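The loader now reads the checkpoint into a plain state dict before constructing the VAE wrapper; a minimal sketch of that two-step load, assuming it runs from the ComfyUI root and using a placeholder file path:

import comfy.sd
import comfy.utils

vae_path = "models/vae/example.safetensors"  # placeholder path
sd = comfy.utils.load_torch_file(vae_path)   # plain dict of tensors
vae = comfy.sd.VAE(sd=sd)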
class ControlNetLoader:
@ -1660,7 +1661,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "KSampler": "KSampler",
    "KSamplerAdvanced": "KSampler (Advanced)",
    # Loaders
-   "CheckpointLoader": "Load Checkpoint (With Config)",
+   "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
    "CheckpointLoaderSimple": "Load Checkpoint",
    "VAELoader": "Load VAE",
    "LoraLoader": "Load LoRA",
@ -1795,7 +1796,8 @@ def init_custom_nodes():
        "nodes_clip_sdxl.py",
        "nodes_canny.py",
        "nodes_freelunch.py",
-       "nodes_custom_sampler.py"
+       "nodes_custom_sampler.py",
+       "nodes_hypertile.py",
    ]

    for node_file in extras_files:

View File

@ -47,7 +47,7 @@
    " !git pull\n",
    "\n",
    "!echo -= Install dependencies =-\n",
-   "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
+   "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117"
   ]
  },
  {

View File

@ -1816,9 +1816,9 @@
      }
    },
    "node_modules/@babel/traverse": {
-     "version": "7.23.0",
-     "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.23.0.tgz",
-     "integrity": "sha512-t/QaEvyIoIkwzpiZ7aoSKK8kObQYeF7T2v+dazAYCb8SXtp58zEVkWW7zAnju8FNKNdr4ScAOEDmMItbyOmEYw==",
+     "version": "7.23.2",
+     "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.23.2.tgz",
+     "integrity": "sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw==",
      "dev": true,
      "dependencies": {
        "@babel/code-frame": "^7.22.13",

View File

@ -4,7 +4,8 @@
  "description": "UI tests",
  "main": "index.js",
  "scripts": {
-   "test": "jest"
+   "test": "jest",
+   "test:generate": "node setup.js"
  },
  "repository": {
    "type": "git",

tests-ui/setup.js (new file, 87 lines)
View File

@ -0,0 +1,87 @@
const { spawn } = require("child_process");
const { resolve } = require("path");
const { existsSync, mkdirSync, writeFileSync } = require("fs");
const http = require("http");

async function setup() {
    // Wait up to 30s for it to start
    let success = false;
    let child;
    for (let i = 0; i < 30; i++) {
        try {
            await new Promise((res, rej) => {
                http
                    .get("http://127.0.0.1:8188/object_info", (resp) => {
                        let data = "";
                        resp.on("data", (chunk) => {
                            data += chunk;
                        });
                        resp.on("end", () => {
                            // Modify the response data to add some checkpoints
                            const objectInfo = JSON.parse(data);
                            objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"];

                            data = JSON.stringify(objectInfo, undefined, "\t");

                            const outDir = resolve("./data");
                            if (!existsSync(outDir)) {
                                mkdirSync(outDir);
                            }
                            const outPath = resolve(outDir, "object_info.json");
                            console.log(`Writing ${Object.keys(objectInfo).length} nodes to ${outPath}`);
                            writeFileSync(outPath, data, {
                                encoding: "utf8",
                            });
                            res();
                        });
                    })
                    .on("error", rej);
            });
            success = true;
            break;
        } catch (error) {
            console.log(i + "/30", error);
            if (i === 0) {
                // Start the server on first iteration if it fails to connect
                console.log("Starting ComfyUI server...");

                let python = resolve("../../python_embeded/python.exe");
                let args;
                let cwd;
                if (existsSync(python)) {
                    args = ["-s", "ComfyUI/main.py"];
                    cwd = "../..";
                } else {
                    python = "python";
                    args = ["main.py"];
                    cwd = "..";
                }
                args.push("--cpu");
                console.log(python, ...args);
                child = spawn(python, args, { cwd });
                child.on("error", (err) => {
                    console.log(`Server error (${err})`);
                    i = 30;
                });
                child.on("exit", (code) => {
                    if (!success) {
                        console.log(`Server exited (${code})`);
                        i = 30;
                    }
                });
            }

            await new Promise((r) => {
                setTimeout(r, 1000);
            });
        }
    }

    child?.kill();

    if (!success) {
        throw new Error("Waiting for server failed...");
    }
}

setup();
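When debugging the test setup by hand, the same wait-and-dump flow can be reproduced from Python against the endpoint used above; a rough sketch (the 30-attempt budget mirrors the script, the output path is arbitrary):

import json, time, urllib.request

for attempt in range(30):
    try:
        with urllib.request.urlopen("http://127.0.0.1:8188/object_info") as resp:
            object_info = json.load(resp)
        with open("object_info.json", "w", encoding="utf8") as f:
            json.dump(object_info, f, indent="\t")
        break
    except OSError:
        time.sleep(1)   # server not up yet, retry
else:
    raise RuntimeError("Waiting for server failed...")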

View File

@ -5,6 +5,61 @@ function setNodeMode(node, mode) {
    node.graph.change();
}
function addNodesToGroup(group, nodes=[]) {
    var x1, y1, x2, y2;
    var nx1, ny1, nx2, ny2;
    var node;

    x1 = y1 = x2 = y2 = -1;
    nx1 = ny1 = nx2 = ny2 = -1;

    for (var n of [group._nodes, nodes]) {
        for (var i in n) {
            node = n[i]

            nx1 = node.pos[0]
            ny1 = node.pos[1]
            nx2 = node.pos[0] + node.size[0]
            ny2 = node.pos[1] + node.size[1]

            if (node.type != "Reroute") {
                ny1 -= LiteGraph.NODE_TITLE_HEIGHT;
            }

            if (node.flags?.collapsed) {
                ny2 = ny1 + LiteGraph.NODE_TITLE_HEIGHT;

                if (node?._collapsed_width) {
                    nx2 = nx1 + Math.round(node._collapsed_width);
                }
            }

            if (x1 == -1 || nx1 < x1) {
                x1 = nx1;
            }
            if (y1 == -1 || ny1 < y1) {
                y1 = ny1;
            }
            if (x2 == -1 || nx2 > x2) {
                x2 = nx2;
            }
            if (y2 == -1 || ny2 > y2) {
                y2 = ny2;
            }
        }
    }

    var padding = 10;

    y1 = y1 - Math.round(group.font_size * 1.4);

    group.pos = [x1 - padding, y1 - padding];
    group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2];
}
app.registerExtension({
    name: "Comfy.GroupOptions",
    setup() {
@ -14,6 +69,17 @@ app.registerExtension({
            const options = orig.apply(this, arguments);
            const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]);
            if (!group) {
                options.push({
                    content: "Add Group For Selected Nodes",
                    disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
                    callback: () => {
                        var group = new LiteGraph.LGraphGroup();
                        addNodesToGroup(group, this.selected_nodes)
                        app.canvas.graph.add(group);
                        this.graph.change();
                    }
                });
                return options;
            }
@ -21,6 +87,15 @@ app.registerExtension({
            group.recomputeInsideNodes();
            const nodesInGroup = group._nodes;
            options.push({
                content: "Add Selected Nodes To Group",
                disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
                callback: () => {
                    addNodesToGroup(group, this.selected_nodes)
                    this.graph.change();
                }
            });
            // No nodes in group, return default options
            if (nodesInGroup.length === 0) {
                return options;
@ -38,6 +113,14 @@ app.registerExtension({
                }
            }
            options.push({
                content: "Fit Group To Nodes",
                callback: () => {
                    addNodesToGroup(group)
                    this.graph.change();
                }
            });
            options.push({
                content: "Select Nodes",
                callback: () => {

View File

@ -22,6 +22,15 @@ class ManageTemplates extends ComfyDialog {
        super();
        this.element.classList.add("comfy-manage-templates");
        this.templates = this.load();
        this.importInput = $el("input", {
            type: "file",
            accept: ".json",
            multiple: true,
            style: {display: "none"},
            parent: document.body,
            onchange: () => this.importAll(),
        });
    }

    createButtons() {
@ -34,6 +43,22 @@ class ManageTemplates extends ComfyDialog {
                onclick: () => this.save(),
            })
        );
        btns.unshift(
            $el("button", {
                type: "button",
                textContent: "Export",
                onclick: () => this.exportAll(),
            })
        );
        btns.unshift(
            $el("button", {
                type: "button",
                textContent: "Import",
                onclick: () => {
                    this.importInput.click();
                },
            })
        );
        return btns;
    }
@ -69,6 +94,52 @@ class ManageTemplates extends ComfyDialog {
        localStorage.setItem(id, JSON.stringify(this.templates));
    }
    async importAll() {
        for (const file of this.importInput.files) {
            if (file.type === "application/json" || file.name.endsWith(".json")) {
                const reader = new FileReader();
                reader.onload = async () => {
                    var importFile = JSON.parse(reader.result);
                    if (importFile && importFile?.templates) {
                        for (const template of importFile.templates) {
                            if (template?.name && template?.data) {
                                this.templates.push(template);
                            }
                        }
                        this.store();
                    }
                };
                await reader.readAsText(file);
            }
        }

        this.importInput.value = null;
        this.close();
    }

    exportAll() {
        if (this.templates.length == 0) {
            alert("No templates to export.");
            return;
        }
        const json = JSON.stringify({templates: this.templates}, null, 2); // convert the data to a JSON string
        const blob = new Blob([json], {type: "application/json"});
        const url = URL.createObjectURL(blob);
        const a = $el("a", {
            href: url,
            download: "node_templates.json",
            style: {display: "none"},
            parent: document.body,
        });
        a.click();
        setTimeout(function () {
            a.remove();
            window.URL.revokeObjectURL(url);
        }, 0);
    }
    show() {
        // Show list of template names + delete button
        super.show(
@ -97,19 +168,48 @@ class ManageTemplates extends ComfyDialog {
                        }),
                    ]
                ),
-               $el("button", {
-                   textContent: "Delete",
-                   style: {
-                       fontSize: "12px",
-                       color: "red",
-                       fontWeight: "normal",
-                   },
-                   onclick: (e) => {
-                       nameInput.value = "";
-                       e.target.style.display = "none";
-                       e.target.previousElementSibling.style.display = "none";
-                   },
-               }),
+               $el(
+                   "div",
+                   {},
+                   [
+                       $el("button", {
+                           textContent: "Export",
+                           style: {
+                               fontSize: "12px",
+                               fontWeight: "normal",
+                           },
+                           onclick: (e) => {
+                               const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
+                               const blob = new Blob([json], {type: "application/json"});
+                               const url = URL.createObjectURL(blob);
+                               const a = $el("a", {
+                                   href: url,
+                                   download: (nameInput.value || t.name) + ".json",
+                                   style: {display: "none"},
+                                   parent: document.body,
+                               });
+                               a.click();
+                               setTimeout(function () {
+                                   a.remove();
+                                   window.URL.revokeObjectURL(url);
+                               }, 0);
+                           },
+                       }),
+                       $el("button", {
+                           textContent: "Delete",
+                           style: {
+                               fontSize: "12px",
+                               color: "red",
+                               fontWeight: "normal",
+                           },
+                           onclick: (e) => {
+                               nameInput.value = "";
+                               e.target.parentElement.style.display = "none";
+                               e.target.parentElement.previousElementSibling.style.display = "none";
+                           },
+                       }),
+                   ]
+               ),
            ];
        })
    )
@ -164,19 +264,17 @@ app.registerExtension({
            },
        }));

-       if (subItems.length) {
-           subItems.push(null, {
-               content: "Manage",
-               callback: () => manage.show(),
-           });
+       subItems.push(null, {
+           content: "Manage",
+           callback: () => manage.show(),
+       });

        options.push({
            content: "Node Templates",
            submenu: {
                options: subItems,
            },
        });
-       }

        return options;
    };

View File

@ -3796,7 +3796,7 @@
        out = out || new Float32Array(4);
        out[0] = this.pos[0] - 4;
        out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
-       out[2] = this.size[0] + 4;
+       out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
        out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
        if (this.onBounding) {

View File

@ -492,7 +492,7 @@ export class ComfyApp {
            }

            if (this.imgs && this.imgs.length) {
-               const canvas = graph.list_of_graphcanvas[0];
+               const canvas = app.graph.list_of_graphcanvas[0];
                const mouse = canvas.graph_mouse;
                if (!canvas.pointer_is_down && this.pointerDown) {
                    if (mouse[0] === this.pointerDown.pos[0] && mouse[1] === this.pointerDown.pos[1]) {
@ -1430,6 +1430,43 @@ export class ComfyApp {
        }
    }
    loadTemplateData(templateData) {
        if (!templateData?.templates) {
            return;
        }

        const old = localStorage.getItem("litegrapheditor_clipboard");

        var maxY, nodeBottom, node;

        for (const template of templateData.templates) {
            if (!template?.data) {
                continue;
            }

            localStorage.setItem("litegrapheditor_clipboard", template.data);
            app.canvas.pasteFromClipboard();

            // Move mouse position down to paste the next template below
            maxY = false;
            for (const i in app.canvas.selected_nodes) {
                node = app.canvas.selected_nodes[i];
                nodeBottom = node.pos[1] + node.size[1];
                if (maxY === false || nodeBottom > maxY) {
                    maxY = nodeBottom;
                }
            }
            app.canvas.graph_mouse[1] = maxY + 50;
        }

        localStorage.setItem("litegrapheditor_clipboard", old);
    }
    /**
     * Populates the graph with the specified workflow data
     * @param {*} graphData A serialized graph object
@ -1775,7 +1812,12 @@ export class ComfyApp {
        } else if (file.type === "application/json" || file.name?.endsWith(".json")) {
            const reader = new FileReader();
            reader.onload = async () => {
-               await this.loadGraphData(JSON.parse(reader.result));
+               var jsonContent = JSON.parse(reader.result);
+               if (jsonContent?.templates) {
+                   this.loadTemplateData(jsonContent);
+               } else {
+                   await this.loadGraphData(jsonContent);
+               }
            };
            reader.readAsText(file);
        } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {