diff --git a/.github/workflows/test-ui.yaml b/.github/workflows/test-ui.yaml index 1470d110c..950691755 100644 --- a/.github/workflows/test-ui.yaml +++ b/.github/workflows/test-ui.yaml @@ -10,8 +10,17 @@ jobs: - uses: actions/setup-node@v3 with: node-version: 18 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Install requirements + run: | + python -m pip install --upgrade pip + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt - name: Run Tests run: | - npm install + npm ci + npm run test:generate npm test working-directory: ./tests-ui diff --git a/.gitignore b/.gitignore index 98d91318d..43c038e41 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ venv/ /web/extensions/* !/web/extensions/logging.js.example !/web/extensions/core/ +/tests-ui/data/object_info.json \ No newline at end of file diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1206c680d..e085186ef 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -92,8 +92,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json") elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") - else: + elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") + else: + return None + clip = ClipVisionModel(json_config) m, u = clip.load_sd(sd) if len(m) > 0: diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index a52e0102b..c0b420e79 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -31,6 +31,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire vae = None if output_vae: - vae = comfy.sd.VAE(ckpt_path=vae_path) + sd = comfy.utils.load_torch_file(vae_path) + vae = comfy.sd.VAE(sd=sd) return (unet, clip, vae) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 7e88bb9fa..58e030d04 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -713,8 +713,8 @@ class UniPC: method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False ): - t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end - t_T = self.noise_schedule.T if t_start is None else t_start + # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + # t_T = self.noise_schedule.T if t_start is None else t_start device = x.device steps = len(timesteps) - 1 if method == 'multistep': @@ -769,8 +769,8 @@ class UniPC: callback(step_index, model_prev_list[-1], x, steps) else: raise NotImplementedError() - if denoise_to_zero: - x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + # if denoise_to_zero: + # x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) return x @@ -833,21 +833,33 @@ def expand_dims(v, dims): return v[(...,) + (None,)*(dims - 1)] +class SigmaConvert: + schedule = "" + def marginal_log_mean_coeff(self, sigma): + return 0.5 * torch.log(1 / ((sigma * sigma) + 1)) + + def marginal_alpha(self, t): + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + return torch.sqrt(1. - torch.exp(2. 
* self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): - to_zero = False + timesteps = sigmas.clone() if sigmas[-1] == 0: - timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] - to_zero = True + timesteps = sigmas[:] + timesteps[-1] = 0.001 else: timesteps = sigmas.clone() - - alphas_cumprod = model.inner_model.alphas_cumprod - - for s in range(timesteps.shape[0]): - timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod)) - - ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + ns = SigmaConvert() if image is not None: img = image * ns.marginal_alpha(timesteps[0]) @@ -859,16 +871,10 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex else: img = noise - if to_zero: - timesteps[-1] = (1 / len(alphas_cumprod)) - - device = noise.device - - model_type = "noise" model_fn = model_wrapper( - model.predict_eps_discrete_timestep, + model.predict_eps_sigma, ns, model_type=model_type, guidance_type="uncond", @@ -878,6 +884,5 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex order = min(3, len(timesteps) - 1) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) - if not to_zero: - x /= ns.marginal_alpha(timesteps[-1]) + x /= ns.marginal_alpha(timesteps[-1]) return x diff --git a/comfy/k_diffusion/external.py b/comfy/k_diffusion/external.py index c1a137d9c..953d3db2c 100644 --- a/comfy/k_diffusion/external.py +++ b/comfy/k_diffusion/external.py @@ -97,6 +97,10 @@ class DiscreteSchedule(nn.Module): input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5) return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim) + def predict_eps_sigma(self, input, sigma, **kwargs): + input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5) + return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim) + class DiscreteEpsDDPMDenoiser(DiscreteSchedule): """A wrapper for discrete schedule DDPM models that output eps (the predicted noise).""" diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 1fb7ed879..d2f1d74a9 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -2,67 +2,66 @@ import torch # import pytorch_lightning as pl import torch.nn.functional as F from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union -from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution from comfy.ldm.util import instantiate_from_config from comfy.ldm.modules.ema import LitEma -# class AutoencoderKL(pl.LightningModule): -class AutoencoderKL(torch.nn.Module): - def __init__(self, - ddconfig, - lossconfig, 
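
[Editor's note] The SigmaConvert schedule introduced above replaces the discrete NoiseScheduleVP in sample_unipc. Treating the model as sigma-parameterized under a variance-preserving schedule, its log-mean coefficient 0.5*log(1/(sigma^2+1)) implies the closed forms alpha_t = 1/sqrt(sigma^2+1), std_t = sigma/sqrt(sigma^2+1), and lambda_t = -log(sigma). A standalone sanity check of those identities (not part of the patch; plain PyTorch only):

    import torch

    # Closed forms implied by SigmaConvert.marginal_log_mean_coeff:
    # alpha_t = 1/sqrt(sigma^2 + 1), std_t = sigma/sqrt(sigma^2 + 1),
    # lambda_t = log(alpha_t) - log(std_t) = -log(sigma).
    sigma = torch.tensor([0.1, 1.0, 14.6])
    log_alpha = 0.5 * torch.log(1 / (sigma * sigma + 1))
    alpha = torch.exp(log_alpha)
    std = torch.sqrt(1. - torch.exp(2. * log_alpha))
    lam = log_alpha - 0.5 * torch.log(1. - torch.exp(2. * log_alpha))
    assert torch.allclose(alpha, 1 / torch.sqrt(sigma ** 2 + 1))
    assert torch.allclose(std, sigma / torch.sqrt(sigma ** 2 + 1))
    assert torch.allclose(lam, -torch.log(sigma), atol=1e-5)

This is also why the new code divides the final sample by marginal_alpha(timesteps[-1]): the sampler works in the scaled (VP) space, and the division maps back to sigma space.
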
- embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ema_decay=None, - learn_logvar=False - ): +class DiagonalGaussianRegularizer(torch.nn.Module): + def __init__(self, sample: bool = True): super().__init__() - self.learn_logvar = learn_logvar - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - self.embed_dim = embed_dim - if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + self.sample = sample + + def get_trainable_parameters(self) -> Any: + yield from () + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + log = dict() + posterior = DiagonalGaussianDistribution(z) + if self.sample: + z = posterior.sample() + else: + z = posterior.mode() + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + log["kl_loss"] = kl_loss + return z, log + + +class AbstractAutoencoder(torch.nn.Module): + """ + This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, + unCLIP models, etc. Hence, it is fairly general, and specific features + (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. + """ + + def __init__( + self, + ema_decay: Union[None, float] = None, + monitor: Union[None, str] = None, + input_key: str = "jpg", + **kwargs, + ): + super().__init__() + + self.input_key = input_key + self.use_ema = ema_decay is not None if monitor is not None: self.monitor = monitor - self.use_ema = ema_decay is not None if self.use_ema: - self.ema_decay = ema_decay - assert 0. < ema_decay < 1. 
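
[Editor's note] DiagonalGaussianRegularizer above defers the actual sampling to DiagonalGaussianDistribution. For reference, a minimal sketch of that distribution, assuming the standard LDM implementation (the encoder's output channels are split in half into mean and logvar, and sample() uses the reparameterization trick); the class name here is illustrative, not the library's:

    import torch

    class DiagonalGaussianSketch:
        def __init__(self, parameters):
            # Split 2*embed_dim channels into mean and logvar halves.
            self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
            self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
            self.std = torch.exp(0.5 * self.logvar)

        def sample(self):
            return self.mean + self.std * torch.randn_like(self.mean)

        def mode(self):
            return self.mean

        def kl(self):
            # KL(N(mean, std) || N(0, 1)), summed over non-batch dims (4D input assumed)
            return 0.5 * torch.sum(self.mean ** 2 + self.std ** 2 - 1.0 - self.logvar,
                                   dim=[1, 2, 3])

    z_params = torch.randn(1, 8, 4, 4)  # e.g. 2*embed_dim channels from quant_conv
    dist = DiagonalGaussianSketch(z_params)
    z = dist.sample()                   # shape (1, 4, 4, 4)
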
self.model_ema = LitEma(self, decay=ema_decay) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + def get_input(self, batch) -> Any: + raise NotImplementedError() - def init_from_ckpt(self, path, ignore_keys=list()): - if path.lower().endswith(".safetensors"): - import safetensors.torch - sd = safetensors.torch.load_file(path, device="cpu") - else: - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") + def on_train_batch_end(self, *args, **kwargs): + # for EMA computation + if self.use_ema: + self.model_ema(self) @contextmanager def ema_scope(self, context=None): @@ -70,154 +69,159 @@ class AutoencoderKL(torch.nn.Module): self.model_ema.store(self.parameters()) self.model_ema.copy_to(self) if context is not None: - print(f"{context}: Switched to EMA weights") + logpy.info(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: self.model_ema.restore(self.parameters()) if context is not None: - print(f"{context}: Restored training weights") + logpy.info(f"{context}: Restored training weights") - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self) + def encode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("encode()-method of abstract base class called") - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior + def decode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("decode()-method of abstract base class called") - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec + def instantiate_optimizer_from_config(self, params, lr, cfg): + logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior + def configure_optimizers(self) -> Any: + raise NotImplementedError() - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - return x - def training_step(self, batch, batch_idx, optimizer_idx): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) +class AutoencodingEngine(AbstractAutoencoder): + """ + Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL + (we also restore them explicitly as special cases for legacy reasons). + Regularizations such as KL or VQ are moved to the regularizer class. 
+ """ - if optimizer_idx == 0: - # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return aeloss + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + regularizer_config: Dict, + **kwargs, + ): + super().__init__(*args, **kwargs) - if optimizer_idx == 1: - # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - - self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return discloss - - def validation_step(self, batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") - return log_dict - - def _validation_step(self, batch, batch_idx, postfix=""): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr = self.learning_rate - ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( - self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) - if self.learn_logvar: - print(f"{self.__class__.__name__}: Learning logvar") - ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], [] + self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) + self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) + self.regularization: AbstractRegularizer = instantiate_from_config( + regularizer_config + ) def get_last_layer(self): - return self.decoder.conv_out.weight + return self.decoder.get_last_layer() - @torch.no_grad() - def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if not only_inputs: - xrec, posterior = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["samples"] = self.decode(torch.randn_like(posterior.sample())) - log["reconstructions"] = xrec - if log_ema or self.use_ema: - with self.ema_scope(): - xrec_ema, posterior_ema = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec_ema.shape[1] > 3 - xrec_ema = self.to_rgb(xrec_ema) - log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) - log["reconstructions_ema"] = xrec_ema - log["inputs"] = x - return log + def encode( + 
self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + z = self.encoder(x) + if unregularized: + return z, dict() + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.decoder(z, **kwargs) return x + def forward( + self, x: torch.Tensor, **additional_decode_kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True) + dec = self.decode(z, **additional_decode_kwargs) + return z, dec, reg_log -class IdentityFirstStage(torch.nn.Module): - def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface - super().__init__() - def encode(self, x, *args, **kwargs): - return x +class AutoencodingEngineLegacy(AutoencodingEngine): + def __init__(self, embed_dim: int, **kwargs): + self.max_batch_size = kwargs.pop("max_batch_size", None) + ddconfig = kwargs.pop("ddconfig") + super().__init__( + encoder_config={ + "target": "comfy.ldm.modules.diffusionmodules.model.Encoder", + "params": ddconfig, + }, + decoder_config={ + "target": "comfy.ldm.modules.diffusionmodules.model.Decoder", + "params": ddconfig, + }, + **kwargs, + ) + self.quant_conv = torch.nn.Conv2d( + (1 + ddconfig["double_z"]) * ddconfig["z_channels"], + (1 + ddconfig["double_z"]) * embed_dim, + 1, + ) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim - def decode(self, x, *args, **kwargs): - return x + def get_autoencoder_params(self) -> list: + params = super().get_autoencoder_params() + return params - def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x + def encode( + self, x: torch.Tensor, return_reg_log: bool = False + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if self.max_batch_size is None: + z = self.encoder(x) + z = self.quant_conv(z) + else: + N = x.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + z = list() + for i_batch in range(n_batches): + z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs]) + z_batch = self.quant_conv(z_batch) + z.append(z_batch) + z = torch.cat(z, 0) - def forward(self, x, *args, **kwargs): - return x + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.max_batch_size is None: + dec = self.post_quant_conv(z) + dec = self.decoder(dec, **decoder_kwargs) + else: + N = z.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + dec = list() + for i_batch in range(n_batches): + dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs]) + dec_batch = self.decoder(dec_batch, **decoder_kwargs) + dec.append(dec_batch) + dec = torch.cat(dec, 0) + + return dec + + +class AutoencoderKL(AutoencodingEngineLegacy): + def __init__(self, **kwargs): + if "lossconfig" in kwargs: + kwargs["loss_config"] = kwargs.pop("lossconfig") + super().__init__( + regularizer_config={ + "target": ( + "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer" + ) + }, + **kwargs, + ) diff --git 
a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index ac0d9c8c5..9cd14a537 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -285,15 +285,14 @@ def attention_pytorch(q, k, v, heads, mask=None): ) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) - - if exists(mask): - raise NotImplementedError out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) ) return out + optimized_attention = attention_basic +optimized_attention_masked = attention_basic if model_management.xformers_enabled(): print("Using xformers cross attention") @@ -309,6 +308,9 @@ else: print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad +if model_management.pytorch_attention_enabled(): + optimized_attention_masked = attention_pytorch + class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() @@ -334,7 +336,10 @@ class CrossAttention(nn.Module): else: v = self.to_v(context) - out = optimized_attention(q, k, v, self.heads, mask) + if mask is None: + out = optimized_attention(q, k, v, self.heads) + else: + out = optimized_attention_masked(q, k, v, self.heads, mask) return self.to_out(out) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 6576df4b6..f23417fd2 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -193,6 +193,52 @@ def slice_attention(q, k, v): return r1 +def normal_attention(q, k, v): + # compute attention + b,c,h,w = q.shape + + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + v = v.reshape(b,c,h*w) + + r1 = slice_attention(q, k, v) + h_ = r1.reshape(b,c,h,w) + del r1 + return h_ + +def xformers_attention(q, k, v): + # compute attention + B, C, H, W = q.shape + q, k, v = map( + lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(), + (q, k, v), + ) + + try: + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + out = out.transpose(1, 2).reshape(B, C, H, W) + except NotImplementedError as e: + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + return out + +def pytorch_attention(q, k, v): + # compute attention + B, C, H, W = q.shape + q, k, v = map( + lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), + (q, k, v), + ) + + try: + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = out.transpose(2, 3).reshape(B, C, H, W) + except model_management.OOM_EXCEPTION as e: + print("scaled_dot_product_attention OOMed: switched to slice attention") + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + return out + + class AttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() @@ -220,6 +266,16 @@ class AttnBlock(nn.Module): stride=1, padding=0) + if model_management.xformers_enabled_vae(): + print("Using xformers attention in VAE") + self.optimized_attention = xformers_attention + elif model_management.pytorch_attention_enabled(): + print("Using pytorch attention in VAE") + self.optimized_attention = pytorch_attention + else: + 
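
[Editor's note] The CrossAttention change above routes masked calls to optimized_attention_masked, which defaults to attention_basic and is rebound to attention_pytorch only when PyTorch attention is enabled, since scaled_dot_product_attention accepts attn_mask directly. A standalone sketch of that masked path with hypothetical shapes (not the module's own code; assumes head_dim divides the channel count):

    import torch

    def sdpa_masked(q, k, v, heads, mask):
        # (batch, tokens, heads*dim_head) -> (batch, heads, tokens, dim_head)
        b, n, _ = q.shape
        q, k, v = (t.view(b, -1, heads, t.shape[-1] // heads).transpose(1, 2)
                   for t in (q, k, v))
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
        return out.transpose(1, 2).reshape(b, n, -1)

    q = k = v = torch.randn(2, 16, 64)
    mask = torch.ones(16, 16, dtype=torch.bool)  # hypothetical mask; True = attend
    out = sdpa_masked(q, k, v, heads=8, mask=mask)  # shape (2, 16, 64)
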
print("Using split attention in VAE") + self.optimized_attention = normal_attention + def forward(self, x): h_ = x h_ = self.norm(h_) @@ -227,149 +283,15 @@ class AttnBlock(nn.Module): k = self.k(h_) v = self.v(h_) - # compute attention - b,c,h,w = q.shape + h_ = self.optimized_attention(q, k, v) - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - v = v.reshape(b,c,h*w) - - r1 = slice_attention(q, k, v) - h_ = r1.reshape(b,c,h,w) - del r1 h_ = self.proj_out(h_) return x+h_ -class MemoryEfficientAttnBlock(nn.Module): - """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation - """ - # - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.attention_op: Optional[Any] = None - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map( - lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(), - (q, k, v), - ) - - try: - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - out = out.transpose(1, 2).reshape(B, C, H, W) - except NotImplementedError as e: - out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) - - out = self.proj_out(out) - return x+out - -class MemoryEfficientAttnBlockPytorch(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = comfy.ops.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.attention_op: Optional[Any] = None - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map( - lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), - (q, k, v), - ) - - try: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = out.transpose(2, 3).reshape(B, C, H, W) - except model_management.OOM_EXCEPTION as e: - print("scaled_dot_product_attention OOMed: switched to slice attention") - out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) - - out = self.proj_out(out) - return x+out def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' - if 
model_management.xformers_enabled_vae() and attn_type == "vanilla": - attn_type = "vanilla-xformers" - elif model_management.pytorch_attention_enabled() and attn_type == "vanilla": - attn_type = "vanilla-pytorch" - print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - assert attn_kwargs is None - return AttnBlock(in_channels) - elif attn_type == "vanilla-xformers": - print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") - return MemoryEfficientAttnBlock(in_channels) - elif attn_type == "vanilla-pytorch": - return MemoryEfficientAttnBlockPytorch(in_channels) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - raise NotImplementedError() + return AttnBlock(in_channels) class Model(nn.Module): @@ -619,7 +541,10 @@ class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - attn_type="vanilla", **ignorekwargs): + conv_out_op=comfy.ops.Conv2d, + resnet_op=ResnetBlock, + attn_op=AttnBlock, + **ignorekwargs): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch @@ -648,12 +573,12 @@ class Decoder(nn.Module): # middle self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, + self.mid.block_1 = resnet_op(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, + self.mid.attn_1 = attn_op(block_in) + self.mid.block_2 = resnet_op(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) @@ -665,13 +590,13 @@ class Decoder(nn.Module): attn = nn.ModuleList() block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, + block.append(resnet_op(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) + attn.append(attn_op(block_in)) up = nn.Module() up.block = block up.attn = attn @@ -682,13 +607,13 @@ class Decoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = comfy.ops.Conv2d(block_in, + self.conv_out = conv_out_op(block_in, out_ch, kernel_size=3, stride=1, padding=1) - def forward(self, z): + def forward(self, z, **kwargs): #assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape @@ -699,16 +624,16 @@ class Decoder(nn.Module): h = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h = self.mid.block_1(h, temb, **kwargs) + h = self.mid.attn_1(h, **kwargs) + h = self.mid.block_2(h, temb, **kwargs) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb) + h = self.up[i_level].block[i_block](h, temb, **kwargs) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + h = self.up[i_level].attn[i_block](h, **kwargs) if i_level != 0: h = self.up[i_level].upsample(h) @@ -718,7 +643,7 @@ class Decoder(nn.Module): h = self.norm_out(h) h = nonlinearity(h) - h = self.conv_out(h) + h = self.conv_out(h, **kwargs) if self.tanh_out: h = torch.tanh(h) return h diff --git 
a/comfy/model_base.py b/comfy/model_base.py index ed2dc83e4..cda6765e4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -26,6 +26,7 @@ class BaseModel(torch.nn.Module): self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 + self.inpaint_model = False print("model_type", model_type.name) print("adm", self.adm_channels) @@ -71,6 +72,38 @@ class BaseModel(torch.nn.Module): def encode_adm(self, **kwargs): return None + def cond_concat(self, **kwargs): + if self.inpaint_model: + concat_keys = ("mask", "masked_image") + cond_concat = [] + denoise_mask = kwargs.get("denoise_mask", None) + latent_image = kwargs.get("latent_image", None) + noise = kwargs.get("noise", None) + device = kwargs["device"] + + def blank_inpaint_image_like(latent_image): + blank_image = torch.ones_like(latent_image) + # these are the values for "zero" in pixel space translated to latent space + blank_image[:,0] *= 0.8223 + blank_image[:,1] *= -0.6876 + blank_image[:,2] *= 0.6364 + blank_image[:,3] *= 0.1380 + return blank_image + + for ck in concat_keys: + if denoise_mask is not None: + if ck == "mask": + cond_concat.append(denoise_mask[:,:1].to(device)) + elif ck == "masked_image": + cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space + else: + if ck == "mask": + cond_concat.append(torch.ones_like(noise)[:,:1]) + elif ck == "masked_image": + cond_concat.append(blank_inpaint_image_like(noise)) + return cond_concat + return None + def load_model_weights(self, sd, unet_prefix=""): to_load = {} keys = list(sd.keys()) @@ -112,7 +145,7 @@ class BaseModel(torch.nn.Module): return {**unet_state_dict, **vae_state_dict, **clip_state_dict} def set_inpaint(self): - self.concat_keys = ("mask", "masked_image") + self.inpaint_model = True def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): adm_inputs = [] diff --git a/comfy/model_management.py b/comfy/model_management.py index c24c7b27e..64ed19727 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -667,7 +667,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): return False #FP16 is just broken on these cards - nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX"] + nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"] for x in nvidia_16_series: if x in props.name: return False diff --git a/comfy/sample.py b/comfy/sample.py index 322272766..e6a69973d 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -98,6 +98,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative samples = samples.cpu() cleanup_additional_models(models) + cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): @@ -109,5 +110,6 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.cpu() cleanup_additional_models(models) + 
cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) return samples diff --git a/comfy/samplers.py b/comfy/samplers.py index e43f7a6fe..0b38fbd1e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -14,8 +14,8 @@ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) #The main sampling function shared by all the samplers #Returns predicted noise -def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}, seed=None): - def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): +def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): + def get_area_and_mult(cond, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 if 'timestep_start' in cond[1]: @@ -68,12 +68,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con conditionning = {} conditionning['c_crossattn'] = cond[0] - if cond_concat_in is not None and len(cond_concat_in) > 0: - cropped = [] - for x in cond_concat_in: - cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - cropped.append(cr) - conditionning['c_concat'] = torch.cat(cropped, dim=1) + + if 'concat' in cond[1]: + cond_concat_in = cond[1]['concat'] + if cond_concat_in is not None and len(cond_concat_in) > 0: + cropped = [] + for x in cond_concat_in: + cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + cropped.append(cr) + conditionning['c_concat'] = torch.cat(cropped, dim=1) if adm_cond is not None: conditionning['c_adm'] = adm_cond @@ -173,7 +176,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con out['c_adm'] = torch.cat(c_adm) return out - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options): + def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 @@ -185,14 +188,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con to_run = [] for x in cond: - p = get_area_and_mult(x, x_in, cond_concat_in, timestep) + p = get_area_and_mult(x, x_in, timestep) if p is None: continue to_run += [(p, COND)] if uncond is not None: for x in uncond: - p = get_area_and_mult(x, x_in, cond_concat_in, timestep) + p = get_area_and_mult(x, x_in, timestep) if p is None: continue @@ -286,7 +289,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if math.isclose(cond_scale, 1.0): uncond = None - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options) if "sampler_cfg_function" in model_options: args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep} return model_options["sampler_cfg_function"](args) @@ -307,8 +310,8 @@ class CFGNoisePredictor(torch.nn.Module): super().__init__() self.inner_model = model self.alphas_cumprod = model.alphas_cumprod - def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None): - out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options, seed=seed) + def apply_model(self, x, 
timestep, cond, uncond, cond_scale, model_options={}, seed=None): + out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) return out @@ -316,11 +319,11 @@ class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}, seed=None): + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: latent_mask = 1. - denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options, seed=seed) + out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) if denoise_mask is not None: out *= denoise_mask @@ -358,15 +361,6 @@ def sgm_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) -def blank_inpaint_image_like(latent_image): - blank_image = torch.ones_like(latent_image) - # these are the values for "zero" in pixel space translated to latent space - blank_image[:,0] *= 0.8223 - blank_image[:,1] *= -0.6876 - blank_image[:,2] *= 0.6364 - blank_image[:,3] *= 0.1380 - return blank_image - def get_mask_aabb(masks): if masks.numel() == 0: return torch.zeros((0, 4), device=masks.device, dtype=torch.int) @@ -543,6 +537,20 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type): return conds +def encode_cond(model_function, key, conds, device, **kwargs): + for t in range(len(conds)): + x = conds[t] + params = x[1].copy() + params["device"] = device + for k in kwargs: + if k not in params: + params[k] = kwargs[k] + + out = model_function(**params) + if out is not None: + x[1] = x[1].copy() + x[1][key] = out + return conds class Sampler: def sample(self): @@ -662,31 +670,19 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) + if latent_image is not None: + latent_image = model.process_latent_in(latent_image) + if model.is_adm(): positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive") negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") - if latent_image is not None: - latent_image = model.process_latent_in(latent_image) + if hasattr(model, 'cond_concat'): + positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} - cond_concat = None - if hasattr(model, 'concat_keys'): #inpaint - cond_concat = [] - for ck in model.concat_keys: - if denoise_mask is not None: - if ck == "mask": - cond_concat.append(denoise_mask[:,:1]) - elif ck == "masked_image": - 
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space - else: - if ck == "mask": - cond_concat.append(torch.ones_like(noise)[:,:1]) - elif ck == "masked_image": - cond_concat.append(blank_inpaint_image_like(noise)) - extra_args["cond_concat"] = cond_concat - samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return model.process_latent_out(samples.to(torch.float32)) @@ -743,7 +739,7 @@ class KSampler: sigmas = None discard_penultimate_sigma = False - if self.sampler in ['dpm_2', 'dpm_2_ancestral']: + if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']: steps += 1 discard_penultimate_sigma = True diff --git a/comfy/sd.py b/comfy/sd.py index fd8b94df8..c364b723c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -4,7 +4,7 @@ import math from comfy import model_management from .ldm.util import instantiate_from_config -from .ldm.models.autoencoder import AutoencoderKL +from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine import yaml import comfy.utils @@ -140,21 +140,24 @@ class CLIP: return self.patcher.get_key_patches() class VAE: - def __init__(self, ckpt_path=None, device=None, config=None): + def __init__(self, sd=None, device=None, config=None): + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) + if config is None: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - if ckpt_path is not None: - sd = comfy.utils.load_torch_file(ckpt_path) - if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - sd = diffusers_convert.convert_vae_state_dict(sd) - m, u = self.first_stage_model.load_state_dict(sd, strict=False) - if len(m) > 0: - print("Missing VAE keys", m) + + m, u = self.first_stage_model.load_state_dict(sd, strict=False) + if len(m) > 0: + print("Missing VAE keys", m) + + if len(u) > 0: + print("Leftover VAE keys", u) if device is None: device = model_management.vae_device() @@ -183,7 +186,7 @@ class VAE: steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float() + encode_fn = lambda a: self.first_stage_model.encode((2. 
* a - 1.).to(self.vae_dtype).to(self.device)).float() samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) @@ -229,7 +232,7 @@ class VAE: samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") for x in range(0, pixel_samples.shape[0], batch_number): pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) - samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float() + samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float() except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") @@ -375,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl model.load_model_weights(state_dict, "model.diffusion_model.") if output_vae: - w = WeightsLoader() - vae = VAE(config=vae_config) - w.first_stage_model = vae.first_stage_model - load_model_weights(w, state_dict) + vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True) + vae = VAE(sd=vae_sd, config=vae_config) if output_clip: w = WeightsLoader() @@ -427,18 +428,17 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o model.load_model_weights(sd, "model.diffusion_model.") if output_vae: - vae = VAE() - w = WeightsLoader() - w.first_stage_model = vae.first_stage_model - load_model_weights(w, sd) + vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae = VAE(sd=vae_sd) if output_clip: w = WeightsLoader() clip_target = model_config.clip_target() - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - sd = model_config.process_clip_state_dict(sd) - load_model_weights(w, sd) + if clip_target is not None: + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + sd = model_config.process_clip_state_dict(sd) + load_model_weights(w, sd) left_over = sd.keys() if len(left_over) > 0: diff --git a/comfy/utils.py b/comfy/utils.py index df016ef9e..a1807aa1d 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -47,12 +47,17 @@ def state_dict_key_replace(state_dict, keys_to_replace): state_dict[keys_to_replace[x]] = state_dict.pop(x) return state_dict -def state_dict_prefix_replace(state_dict, replace_prefix): +def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False): + if filter_keys: + out = {} + else: + out = state_dict for rp in replace_prefix: replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) for x in replace: - state_dict[x[1]] = state_dict.pop(x[0]) - return state_dict + w = state_dict.pop(x[0]) + out[x[1]] = w + return out def transformers_convert(sd, prefix_from, prefix_to, number): diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index 07a88bd96..7512b841d 100644 --- a/comfy_extras/nodes_freelunch.py +++ 
b/comfy_extras/nodes_freelunch.py @@ -61,7 +61,53 @@ class FreeU: m.set_model_output_block_patch(output_block_patch) return (m, ) +class FreeU_V2: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}), + "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}), + "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), + "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, b1, b2, s1, s2): + model_channels = model.model.model_config.unet_config["model_channels"] + scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} + on_cpu_devices = {} + + def output_block_patch(h, hsp, transformer_options): + scale = scale_dict.get(h.shape[1], None) + if scale is not None: + hidden_mean = h.mean(1).unsqueeze(1) + B = hidden_mean.shape[0] + hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) + + h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1) + + if hsp.device not in on_cpu_devices: + try: + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + except: + print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") + on_cpu_devices[hsp.device] = True + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + else: + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + + return h, hsp + + m = model.clone() + m.set_model_output_block_patch(output_block_patch) + return (m, ) NODE_CLASS_MAPPINGS = { "FreeU": FreeU, + "FreeU_V2": FreeU_V2, } diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index d16c49aeb..f692945a8 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -19,6 +19,7 @@ def load_hypernetwork_patch(path, strength): "tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid, "softsign": torch.nn.Softsign, + "mish": torch.nn.Mish, } if activation_func not in valid_activation: @@ -42,7 +43,8 @@ def load_hypernetwork_patch(path, strength): linears = list(map(lambda a: a[:-len(".weight")], linears)) layers = [] - for i in range(len(linears)): + i = 0 + while i < len(linears): lin_name = linears[i] last_layer = (i == (len(linears) - 1)) penultimate_layer = (i == (len(linears) - 2)) @@ -56,10 +58,17 @@ def load_hypernetwork_patch(path, strength): if (not last_layer) or (activate_output): layers.append(valid_activation[activation_func]()) if is_layer_norm: - layers.append(torch.nn.LayerNorm(lin_weight.shape[0])) + i += 1 + ln_name = linears[i] + ln_weight = attn_weights['{}.weight'.format(ln_name)] + ln_bias = attn_weights['{}.bias'.format(ln_name)] + ln = torch.nn.LayerNorm(ln_weight.shape[0]) + ln.load_state_dict({"weight": ln_weight, "bias": ln_bias}) + layers.append(ln) if use_dropout: if (not last_layer) and (not penultimate_layer or last_layer_dropout): layers.append(torch.nn.Dropout(p=0.3)) + i += 1 output.append(torch.nn.Sequential(*layers)) out[dim] = torch.nn.ModuleList(output) diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py new file mode 100644 index 
000000000..0d7d4c954 --- /dev/null +++ b/comfy_extras/nodes_hypertile.py @@ -0,0 +1,83 @@ +#Taken from: https://github.com/tfernd/HyperTile/ + +import math +from einops import rearrange +import random + +def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int: + min_value = min(min_value, value) + + # All big divisors of value (inclusive) + divisors = [i for i in range(min_value, value + 1) if value % i == 0] + + ns = [value // i for i in divisors[:max_options]] # has at least 1 element + + random.seed(counter) + idx = random.randint(0, len(ns) - 1) + + return ns[idx] + +class HyperTile: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), + "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), + "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), + "scale_depth": ("BOOLEAN", {"default": False}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, tile_size, swap_size, max_depth, scale_depth): + model_channels = model.model.model_config.unet_config["model_channels"] + + apply_to = set() + temp = model_channels + for x in range(max_depth + 1): + apply_to.add(temp) + temp *= 2 + + latent_tile_size = max(32, tile_size) // 8 + self.temp = None + self.counter = 1 + + def hypertile_in(q, k, v, extra_options): + if q.shape[-1] in apply_to: + shape = extra_options["original_shape"] + aspect_ratio = shape[-1] / shape[-2] + + hw = q.size(1) + h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) + + factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 + nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter) + self.counter += 1 + nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter) + self.counter += 1 + + if nh * nw > 1: + q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + self.temp = (nh, nw, h, w) + return q, k, v + + return q, k, v + def hypertile_out(out, extra_options): + if self.temp is not None: + nh, nw, h, w = self.temp + self.temp = None + out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) + out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + return out + + + m = model.clone() + m.set_model_attn1_patch(hypertile_in) + m.set_model_attn1_output_patch(hypertile_out) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "HyperTile": HyperTile, +} diff --git a/nodes.py b/nodes.py index 208cbc84f..61ebbb8b4 100644 --- a/nodes.py +++ b/nodes.py @@ -584,7 +584,8 @@ class VAELoader: #TODO: scale factor? 
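
[Editor's note] The two rearrange calls in hypertile_in/hypertile_out above are exact inverses: tokens are regrouped into an (nh x nw) grid of tiles so attention runs per tile, then restored to the original order. A standalone round-trip check with toy sizes (einops only; variable names mirror the node's):

    import torch
    from einops import rearrange

    b, h, w, c = 1, 8, 8, 4
    nh, nw = 2, 2
    q = torch.arange(b * h * w * c, dtype=torch.float32).reshape(b, h * w, c)

    # Same patterns as hypertile_in / hypertile_out above.
    tiled = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c",
                      h=h // nh, w=w // nw, nh=nh, nw=nw)
    out = rearrange(tiled, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
    out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c",
                    h=h // nh, w=w // nw)
    assert torch.equal(q, out)  # the tiling is lossless

random_divisor guarantees nh and nw divide h and w, which is what makes these reshapes valid.
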
def load_vae(self, vae_name): vae_path = folder_paths.get_full_path("vae", vae_name) - vae = comfy.sd.VAE(ckpt_path=vae_path) + sd = comfy.utils.load_torch_file(vae_path) + vae = comfy.sd.VAE(sd=sd) return (vae,) class ControlNetLoader: @@ -1660,7 +1661,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "KSampler": "KSampler", "KSamplerAdvanced": "KSampler (Advanced)", # Loaders - "CheckpointLoader": "Load Checkpoint (With Config)", + "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)", "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", "LoraLoader": "Load LoRA", @@ -1795,7 +1796,8 @@ def init_custom_nodes(): "nodes_clip_sdxl.py", "nodes_canny.py", "nodes_freelunch.py", - "nodes_custom_sampler.py" + "nodes_custom_sampler.py", + "nodes_hypertile.py", ] for node_file in extras_files: diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index 4fdccaace..ec83265b4 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -47,7 +47,7 @@ " !git pull\n", "\n", "!echo -= Install dependencies =-\n", - "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117" + "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121 --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117" ] }, { diff --git a/tests-ui/package-lock.json b/tests-ui/package-lock.json index 75b32a4cf..35911cd7f 100644 --- a/tests-ui/package-lock.json +++ b/tests-ui/package-lock.json @@ -1816,9 +1816,9 @@ } }, "node_modules/@babel/traverse": { - "version": "7.23.0", - "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.23.0.tgz", - "integrity": "sha512-t/QaEvyIoIkwzpiZ7aoSKK8kObQYeF7T2v+dazAYCb8SXtp58zEVkWW7zAnju8FNKNdr4ScAOEDmMItbyOmEYw==", + "version": "7.23.2", + "resolved": "https://registry.npmjs.org/@babel/traverse/-/traverse-7.23.2.tgz", + "integrity": "sha512-azpe59SQ48qG6nu2CzcMLbxUudtN+dOM9kDbUqGq3HXUJRlo7i8fvPoxQUzYgLZ4cMVmuZgm8vvBpNeRhd6XSw==", "dev": true, "dependencies": { "@babel/code-frame": "^7.22.13", diff --git a/tests-ui/package.json b/tests-ui/package.json index cb8a23ec2..e7b60ad8e 100644 --- a/tests-ui/package.json +++ b/tests-ui/package.json @@ -4,7 +4,8 @@ "description": "UI tests", "main": "index.js", "scripts": { - "test": "jest" + "test": "jest", + "test:generate": "node setup.js" }, "repository": { "type": "git", diff --git a/tests-ui/setup.js b/tests-ui/setup.js new file mode 100644 index 000000000..0f368ab22 --- /dev/null +++ b/tests-ui/setup.js @@ -0,0 +1,87 @@ +const { spawn } = require("child_process"); +const { resolve } = require("path"); +const { existsSync, mkdirSync, writeFileSync } = require("fs"); +const http = require("http"); + +async function setup() { + // Wait up to 30s for it to start + let success = false; + let child; + for (let i = 0; i < 30; i++) { + try { + await new Promise((res, rej) => { + http + .get("http://127.0.0.1:8188/object_info", (resp) => { + let data = ""; + resp.on("data", (chunk) => { + data += chunk; + }); + resp.on("end", () => { + // Modify the response data to add some checkpoints + const objectInfo = JSON.parse(data); + objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"]; + + data = JSON.stringify(objectInfo, undefined, "\t"); + + const outDir = resolve("./data"); + if (!existsSync(outDir)) { + mkdirSync(outDir); + } + + const 
outPath = resolve(outDir, "object_info.json"); + console.log(`Writing ${Object.keys(objectInfo).length} nodes to ${outPath}`); + writeFileSync(outPath, data, { + encoding: "utf8", + }); + res(); + }); + }) + .on("error", rej); + }); + success = true; + break; + } catch (error) { + console.log(i + "/30", error); + if (i === 0) { + // Start the server on first iteration if it fails to connect + console.log("Starting ComfyUI server..."); + + let python = resolve("../../python_embeded/python.exe"); + let args; + let cwd; + if (existsSync(python)) { + args = ["-s", "ComfyUI/main.py"]; + cwd = "../.."; + } else { + python = "python"; + args = ["main.py"]; + cwd = ".."; + } + args.push("--cpu"); + console.log(python, ...args); + child = spawn(python, args, { cwd }); + child.on("error", (err) => { + console.log(`Server error (${err})`); + i = 30; + }); + child.on("exit", (code) => { + if (!success) { + console.log(`Server exited (${code})`); + i = 30; + } + }); + } + await new Promise((r) => { + setTimeout(r, 1000); + }); + } + } + + child?.kill(); + + if (!success) { + throw new Error("Waiting for server failed..."); + } +} + + setup(); \ No newline at end of file diff --git a/web/extensions/core/groupOptions.js b/web/extensions/core/groupOptions.js index d523737dc..5dd21e730 100644 --- a/web/extensions/core/groupOptions.js +++ b/web/extensions/core/groupOptions.js @@ -5,6 +5,61 @@ function setNodeMode(node, mode) { node.graph.change(); } +function addNodesToGroup(group, nodes=[]) { + var x1, y1, x2, y2; + var nx1, ny1, nx2, ny2; + var node; + + x1 = y1 = x2 = y2 = -1; + nx1 = ny1 = nx2 = ny2 = -1; + + for (var n of [group._nodes, nodes]) { + for (var i in n) { + node = n[i] + + nx1 = node.pos[0] + ny1 = node.pos[1] + nx2 = node.pos[0] + node.size[0] + ny2 = node.pos[1] + node.size[1] + + if (node.type != "Reroute") { + ny1 -= LiteGraph.NODE_TITLE_HEIGHT; + } + + if (node.flags?.collapsed) { + ny2 = ny1 + LiteGraph.NODE_TITLE_HEIGHT; + + if (node?._collapsed_width) { + nx2 = nx1 + Math.round(node._collapsed_width); + } + } + + if (x1 == -1 || nx1 < x1) { + x1 = nx1; + } + + if (y1 == -1 || ny1 < y1) { + y1 = ny1; + } + + if (x2 == -1 || nx2 > x2) { + x2 = nx2; + } + + if (y2 == -1 || ny2 > y2) { + y2 = ny2; + } + } + } + + var padding = 10; + + y1 = y1 - Math.round(group.font_size * 1.4); + + group.pos = [x1 - padding, y1 - padding]; + group.size = [x2 - x1 + padding * 2, y2 - y1 + padding * 2]; +} + app.registerExtension({ name: "Comfy.GroupOptions", setup() { @@ -14,6 +69,17 @@ app.registerExtension({ const options = orig.apply(this, arguments); const group = this.graph.getGroupOnPos(this.graph_mouse[0], this.graph_mouse[1]); if (!group) { + options.push({ + content: "Add Group For Selected Nodes", + disabled: !Object.keys(app.canvas.selected_nodes || {}).length, + callback: () => { + var group = new LiteGraph.LGraphGroup(); + addNodesToGroup(group, this.selected_nodes) + app.canvas.graph.add(group); + this.graph.change(); + } + }); + return options; } @@ -21,6 +87,15 @@ app.registerExtension({ group.recomputeInsideNodes(); const nodesInGroup = group._nodes; + options.push({ + content: "Add Selected Nodes To Group", + disabled: !Object.keys(app.canvas.selected_nodes || {}).length, + callback: () => { + addNodesToGroup(group, this.selected_nodes) + this.graph.change(); + } + }); + // No nodes in group, return default options if (nodesInGroup.length === 0) { return options; @@ -38,6 +113,14 @@ app.registerExtension({ } } + options.push({ + content: "Fit Group To Nodes", + callback: () => { + 
+			options.push({
+				content: "Fit Group To Nodes",
+				callback: () => {
+					addNodesToGroup(group)
+					this.graph.change();
+				}
+			});
+
 			options.push({
 				content: "Select Nodes",
 				callback: () => {
diff --git a/web/extensions/core/nodeTemplates.js b/web/extensions/core/nodeTemplates.js
index 7059f826d..434491075 100644
--- a/web/extensions/core/nodeTemplates.js
+++ b/web/extensions/core/nodeTemplates.js
@@ -22,6 +22,15 @@ class ManageTemplates extends ComfyDialog {
 		super();
 		this.element.classList.add("comfy-manage-templates");
 		this.templates = this.load();
+
+		this.importInput = $el("input", {
+			type: "file",
+			accept: ".json",
+			multiple: true,
+			style: {display: "none"},
+			parent: document.body,
+			onchange: () => this.importAll(),
+		});
 	}
 
 	createButtons() {
@@ -34,6 +43,22 @@ class ManageTemplates extends ComfyDialog {
 				onclick: () => this.save(),
 			})
 		);
+		btns.unshift(
+			$el("button", {
+				type: "button",
+				textContent: "Export",
+				onclick: () => this.exportAll(),
+			})
+		);
+		btns.unshift(
+			$el("button", {
+				type: "button",
+				textContent: "Import",
+				onclick: () => {
+					this.importInput.click();
+				},
+			})
+		);
 		return btns;
 	}
 
@@ -69,6 +94,52 @@ class ManageTemplates extends ComfyDialog {
 		localStorage.setItem(id, JSON.stringify(this.templates));
 	}
 
+	async importAll() {
+		for (const file of this.importInput.files) {
+			if (file.type === "application/json" || file.name.endsWith(".json")) {
+				const reader = new FileReader();
+				reader.onload = async () => {
+					var importFile = JSON.parse(reader.result);
+					if (importFile && importFile?.templates) {
+						for (const template of importFile.templates) {
+							if (template?.name && template?.data) {
+								this.templates.push(template);
+							}
+						}
+						this.store();
+					}
+				};
+				await reader.readAsText(file);
+			}
+		}
+
+		this.importInput.value = null;
+
+		this.close();
+	}
+
+	exportAll() {
+		if (this.templates.length == 0) {
+			alert("No templates to export.");
+			return;
+		}
+
+		const json = JSON.stringify({templates: this.templates}, null, 2); // convert the data to a JSON string
+		const blob = new Blob([json], {type: "application/json"});
+		const url = URL.createObjectURL(blob);
+		const a = $el("a", {
+			href: url,
+			download: "node_templates.json",
+			style: {display: "none"},
+			parent: document.body,
+		});
+		a.click();
+		setTimeout(function () {
+			a.remove();
+			window.URL.revokeObjectURL(url);
+		}, 0);
+	}
+
 	show() {
 		// Show list of template names + delete button
 		super.show(
@@ -97,19 +168,48 @@ class ManageTemplates extends ComfyDialog {
 					}),
 				]
 			),
-			$el("button", {
-				textContent: "Delete",
-				style: {
-					fontSize: "12px",
-					color: "red",
-					fontWeight: "normal",
-				},
-				onclick: (e) => {
-					nameInput.value = "";
-					e.target.style.display = "none";
-					e.target.previousElementSibling.style.display = "none";
-				},
-			}),
+			$el(
+				"div",
+				{},
+				[
+					$el("button", {
+						textContent: "Export",
+						style: {
+							fontSize: "12px",
+							fontWeight: "normal",
+						},
+						onclick: (e) => {
+							const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string
+							const blob = new Blob([json], {type: "application/json"});
+							const url = URL.createObjectURL(blob);
+							const a = $el("a", {
+								href: url,
+								download: (nameInput.value || t.name) + ".json",
+								style: {display: "none"},
+								parent: document.body,
+							});
+							a.click();
+							setTimeout(function () {
+								a.remove();
+								window.URL.revokeObjectURL(url);
+							}, 0);
+						},
+					}),
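+					// "Delete" clears the template's name and hides this row's input and buttons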
+					$el("button", {
+						textContent: "Delete",
+						style: {
+							fontSize: "12px",
+							color: "red",
+							fontWeight: "normal",
+						},
+						onclick: (e) => {
+							nameInput.value = "";
+							e.target.parentElement.style.display = "none";
+							e.target.parentElement.previousElementSibling.style.display = "none";
+						},
+					}),
+				]
+			),
 				];
 			})
 		)
@@ -164,19 +264,17 @@ app.registerExtension({
 				},
 			}));
 
-			if (subItems.length) {
-				subItems.push(null, {
-					content: "Manage",
-					callback: () => manage.show(),
-				});
+			subItems.push(null, {
+				content: "Manage",
+				callback: () => manage.show(),
+			});
 
-				options.push({
-					content: "Node Templates",
-					submenu: {
-						options: subItems,
-					},
-				});
-			}
+			options.push({
+				content: "Node Templates",
+				submenu: {
+					options: subItems,
+				},
+			});
 
 			return options;
 		};
diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js
index f81c83a8a..e906590f5 100644
--- a/web/lib/litegraph.core.js
+++ b/web/lib/litegraph.core.js
@@ -3796,7 +3796,7 @@
         out = out || new Float32Array(4);
         out[0] = this.pos[0] - 4;
         out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT;
-        out[2] = this.size[0] + 4;
+        out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4;
         out[3] = this.flags.collapsed ? LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT;
 
         if (this.onBounding) {
diff --git a/web/scripts/app.js b/web/scripts/app.js
index 7159da33c..55d4b120e 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -492,7 +492,7 @@ export class ComfyApp {
 			}
 
 			if (this.imgs && this.imgs.length) {
-				const canvas = graph.list_of_graphcanvas[0];
+				const canvas = app.graph.list_of_graphcanvas[0];
 				const mouse = canvas.graph_mouse;
 				if (!canvas.pointer_is_down && this.pointerDown) {
 					if (mouse[0] === this.pointerDown.pos[0] && mouse[1] === this.pointerDown.pos[1]) {
@@ -1430,6 +1430,43 @@ export class ComfyApp {
 		}
 	}
 
+	loadTemplateData(templateData) {
+		if (!templateData?.templates) {
+			return;
+		}
+
+		const old = localStorage.getItem("litegrapheditor_clipboard");
+
+		var maxY, nodeBottom, node;
+
+		for (const template of templateData.templates) {
+			if (!template?.data) {
+				continue;
+			}
+
+			localStorage.setItem("litegrapheditor_clipboard", template.data);
+			app.canvas.pasteFromClipboard();
+
+			// Move mouse position down to paste the next template below
+
+			maxY = false;
+
+			for (const i in app.canvas.selected_nodes) {
+				node = app.canvas.selected_nodes[i];
+
+				nodeBottom = node.pos[1] + node.size[1];
+
+				if (maxY === false || nodeBottom > maxY) {
+					maxY = nodeBottom;
+				}
+			}
+
+			app.canvas.graph_mouse[1] = maxY + 50;
+		}
+
+		localStorage.setItem("litegrapheditor_clipboard", old);
+	}
+
 	/**
 	 * Populates the graph with the specified workflow data
 	 * @param {*} graphData A serialized graph object
@@ -1775,7 +1812,12 @@ export class ComfyApp {
 				} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
 					const reader = new FileReader();
 					reader.onload = async () => {
-						await this.loadGraphData(JSON.parse(reader.result));
+						var jsonContent = JSON.parse(reader.result);
+						if (jsonContent?.templates) {
+							this.loadTemplateData(jsonContent);
+						} else {
+							await this.loadGraphData(jsonContent);
+						}
 					};
 					reader.readAsText(file);
 				} else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) {